#pragma once

#include <ATen/core/Tensor.h>
#include <cstdint>

#ifdef USE_FBGEMM
#include <fbgemm/FbgemmEmbedding.h>
#endif

namespace at {
namespace native {

void check_arguments(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const c10::optional<Tensor>& per_sample_weights,
    bool include_last_offset);

void make_bag_size_out(
    Tensor& bag_size_out,
    const Tensor& offsets,
    const Tensor& indices,
    const int64_t mode,
    const bool include_last_offset,
    const bool requires_grad);

void make_max_indices_out(
    Tensor& max_indices_out,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const Tensor& bag_size,
    const int64_t mode,
    bool include_last_offset);

void make_offset2bag_out(
    Tensor& offset2bag,
    Tensor& output,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const c10::optional<Tensor>& per_sample_weights,
    const int64_t padding_idx = -1);

#ifdef USE_FBGEMM

template <bool has_weight, typename TIndex, typename TData>
struct _CallbackAndBlockSize {
  using TCallback =
      typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;

  int64_t blockSize = -1;
  TCallback callback = nullptr;

  static TCallback generateCallback(int64_t block_size) {
    return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
        block_size,
        has_weight,
        /* normalize_by_lengths */ false,
        /* prefetch */ 16,
        /* is_weight_positional */ false,
        /* use_offsets */ true);
  }

  _CallbackAndBlockSize() = default;

  explicit _CallbackAndBlockSize(c10::optional<int64_t> maybe_block_size)
      : blockSize(maybe_block_size.value_or(-1)),
        callback(maybe_block_size.has_value()
                     ? generateCallback(maybe_block_size.value())
                     : nullptr) {}
};

template <typename... StorageMixins>
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
  _EmbeddingBagKernelCacheImpl() = default;

  // use each of the mixins to store the corresponding kernel and block size
  explicit _EmbeddingBagKernelCacheImpl(c10::optional<int64_t> maybe_block_size)
      : StorageMixins(maybe_block_size)...
  {}

  // this method is thread safe (call sites may call from different threads)
  template <bool has_weight, typename TIndex, typename TData>
  typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
  getCallback(int64_t block_size) const {
    // if the cache doesn't store the kernel for the incoming block size
    // (i.e. it differs from the one stored in the corresponding mixin),
    // regenerate the kernel without writing it into the cache, so we avoid locks
    if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
      return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
    }
    // else retrieve the cached kernel from the corresponding mixin
    return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
  }
};

// instantiate the cache with the list of storage mixins,
// one for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
    _CallbackAndBlockSize<true, int32_t, float>,
    _CallbackAndBlockSize<false, int32_t, float>,
    _CallbackAndBlockSize<true, int64_t, float>,
    _CallbackAndBlockSize<false, int64_t, float>,
    _CallbackAndBlockSize<true, int32_t, unsigned short>,
    _CallbackAndBlockSize<false, int32_t, unsigned short>,
    _CallbackAndBlockSize<true, int64_t, unsigned short>,
    _CallbackAndBlockSize<false, int64_t, unsigned short>>;
#else
struct _EmbeddingBagKernelCache {
  explicit _EmbeddingBagKernelCache(c10::optional<int64_t> /* maybe_block_size */) {}
};
#endif

void _embedding_bag_cpu_impl_out(
    Tensor& output,
    Tensor& offset2bag,
    Tensor& bag_size,
    Tensor* max_indices,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode = 0,
    const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
    bool include_last_offset = false,
    int64_t padding_idx = -1,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);

void _embedding_bag_cpu_out(
    at::Tensor& output,
    at::Tensor& offset2bag,
    at::Tensor& bag_size,
    at::Tensor* p_max_indices,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    const bool scale_grad_by_freq,
    const int64_t mode,
    const bool sparse,
    const c10::optional<at::Tensor>& per_sample_weights,
    const bool include_last_offset,
    const c10::optional<int64_t>& padding_idx,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);

} // native
} // at
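
// A minimal usage sketch (illustrative only, not part of the original header):
// how a caller might fetch a cached FBGEMM kernel for the unweighted
// float/int64 case and invoke it. The variable names (embedding_dim, num_bags,
// weight_data, etc.) are hypothetical, and the kernel argument order is an
// assumption based on fbgemm::EmbeddingSpMDMKernelSignature; the real call
// sites live in EmbeddingBag.cpp and may differ.
//
//   #ifdef USE_FBGEMM
//   _EmbeddingBagKernelCache cache(c10::make_optional<int64_t>(embedding_dim));
//   auto kernel =
//       cache.getCallback</*has_weight=*/false, int64_t, float>(embedding_dim);
//   bool success = kernel(
//       /*output_size=*/num_bags,
//       /*index_size=*/num_indices,
//       /*data_size=*/num_weight_rows,
//       /*input=*/weight_data,
//       /*indices=*/indices_data,
//       /*offsets_or_lengths=*/offsets_data,
//       /*weights=*/nullptr,   // no per-sample weights in this instantiation
//       /*out=*/output_data);
//   #endif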