#pragma once #include namespace at { class Tensor; class TensorBase; struct TensorIterator; struct TensorIteratorBase; } namespace c10 { class Scalar; } namespace at { namespace native { using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides); using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source); using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride); using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate); using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate); using take_fn = void(*)(TensorIterator & iter, const TensorBase& input); using flip_fn = void(*)(TensorIterator &, const bool); using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar); using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride); using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &); DECLARE_DISPATCH(index_fn, index_stub); DECLARE_DISPATCH(index_fill_fn, index_fill_stub); DECLARE_DISPATCH(index_copy_fn, index_copy_stub); DECLARE_DISPATCH(index_put_fn, index_put_stub); DECLARE_DISPATCH(put_fn, put_stub); DECLARE_DISPATCH(take_fn, take_stub); DECLARE_DISPATCH(flip_fn, flip_stub); DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub); DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub); DECLARE_DISPATCH(masked_select_fn, masked_select_stub); DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub); }} // namespace at::native