File size: 4,909 Bytes
7e50900 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
#include <ATen/ATen.h>
#include <cstdint>
#ifdef USE_FBGEMM
#include <fbgemm/FbgemmEmbedding.h>
#endif
namespace at {
namespace native {
// Declaration only; the definition is not visible in this header.
// Returns nothing and takes every argument by const reference or by value,
// so it cannot modify the tensors' handles through these parameters.
// NOTE(review): judging by the name, this presumably validates the
// embedding-bag inputs before a kernel runs — confirm against the definition.
//
// per_sample_weights is optional and may be absent.
// NOTE(review): the meaning of the integer `mode` values is not defined
// here — see the implementation file.
void check_arguments(
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const int64_t mode,
const c10::optional<Tensor>& per_sample_weights,
bool include_last_offset);
// Declaration only; the definition is not visible in this header.
// bag_size_out is the only mutable (non-const reference) parameter, so it is
// the destination this function writes to; all other arguments are inputs.
// NOTE(review): presumably produces the per-bag sizes from `offsets` and
// `indices`; how `mode`, `include_last_offset` and `requires_grad` influence
// the result must be confirmed against the definition.
void make_bag_size_out(
Tensor& bag_size_out,
const Tensor& offsets,
const Tensor& indices,
const int64_t mode,
const bool include_last_offset,
const bool requires_grad);
// Declaration only; the definition is not visible in this header.
// max_indices_out is the only mutable (non-const reference) parameter, so it
// is the destination this function writes to; all other arguments are inputs.
// NOTE(review): presumably prepares the tensor that records arg-max indices
// for a "max" reduction mode — confirm against the definition.
void make_max_indices_out(
Tensor& max_indices_out,
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const Tensor& bag_size,
const int64_t mode,
bool include_last_offset);
// Declaration only; the definition is not visible in this header.
// offset2bag and output are the mutable (non-const reference) parameters;
// the remaining arguments are inputs.
// per_sample_weights is optional and may be absent.
// padding_idx defaults to -1 when the caller omits it.
// NOTE(review): -1 presumably means "no padding index" — confirm against
// the definition.
void make_offset2bag_out(
Tensor& offset2bag,
Tensor& output,
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const int64_t mode,
const c10::optional<Tensor>& per_sample_weights,
const int64_t padding_idx = -1);
#ifdef USE_FBGEMM
// Pairs one generated FBGEMM embedding kernel with the block size it was
// generated for.
//
// Template parameters:
//   has_weight - forwarded as the second argument of
//                fbgemm::GenerateEmbeddingSpMDM
//   TIndex     - second and third template argument of the FBGEMM kernel
//   TData      - first and fourth template argument of the FBGEMM kernel
//
// A default-constructed instance holds no kernel: blockSize is -1 and
// callback is nullptr.
template <bool has_weight, typename TIndex, typename TData>
struct _CallbackAndBlockSize {
  using TCallback = typename fbgemm::
      EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;

  // Block size the stored kernel was generated for; -1 when nothing is stored.
  int64_t blockSize = -1;
  // The stored kernel; nullptr when nothing is stored.
  TCallback callback = nullptr;

  // Generates a fresh kernel for `block_size`. Apart from the block size and
  // has_weight, every generation option is fixed to the constants below.
  static TCallback generateCallback(int64_t block_size) {
    constexpr bool kNormalizeByLengths = false;
    constexpr int kPrefetch = 16;
    constexpr bool kIsWeightPositional = false;
    constexpr bool kUseOffsets = true;
    return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
        block_size,
        has_weight,
        kNormalizeByLengths,
        kPrefetch,
        kIsWeightPositional,
        kUseOffsets);
  }

  _CallbackAndBlockSize() = default;

  // When a block size is supplied, records it and generates the matching
  // kernel right away; otherwise the "nothing stored" defaults are kept.
  explicit _CallbackAndBlockSize(c10::optional<int64_t> maybe_block_size) {
    if (maybe_block_size.has_value()) {
      blockSize = maybe_block_size.value();
      callback = generateCallback(blockSize);
    }
  }
};
// Kernel cache assembled from privately inherited storage mixins. Each mixin
// is expected to expose `blockSize`, `callback` and a static
// `generateCallback`, as _CallbackAndBlockSize does.
template <typename... StorageMixins>
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
  _EmbeddingBagKernelCacheImpl() = default;

  // Hands the same optional block size to every mixin, so each one stores
  // its own kernel for that size.
  explicit _EmbeddingBagKernelCacheImpl(c10::optional<int64_t> maybe_block_size)
      : StorageMixins(maybe_block_size)... {}

  // Returns the kernel for `block_size` from the mixin selected by
  // <has_weight, TIndex, TData>.
  //
  // The method is const and never writes to the cache: on a miss the kernel
  // is generated and returned without being stored, which is what lets call
  // sites on different threads use it without locking.
  template <bool has_weight, typename TIndex, typename TData>
  typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
  getCallback(int64_t block_size) const {
    using Slot = _CallbackAndBlockSize<has_weight, TIndex, TData>;
    const Slot& slot = static_cast<const Slot&>(*this);

    // Hit: the mixin already holds a kernel for exactly this block size.
    if (slot.blockSize == block_size) {
      return slot.callback;
    }
    // Miss: build a one-off kernel for the requested size.
    return Slot::generateCallback(block_size);
  }
};
// Concrete cache type used when USE_FBGEMM is defined.
// It instantiates the cache with eight storage mixins, one for every
// combination of has_weight {true, false}, index type {int32_t, int64_t}
// and data type {float, unsigned short}. According to the original note,
// these correspond to the 8 _EmbeddingBagKernelCache* usages in the
// EmbeddingBag.cpp implementation file.
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
_CallbackAndBlockSize<true, int32_t, float>,
_CallbackAndBlockSize<false, int32_t, float>,
_CallbackAndBlockSize<true, int64_t, float>,
_CallbackAndBlockSize<false, int64_t, float>,
_CallbackAndBlockSize<true, int32_t, unsigned short>,
_CallbackAndBlockSize<false, int32_t, unsigned short>,
_CallbackAndBlockSize<true, int64_t, unsigned short>,
_CallbackAndBlockSize<false, int64_t, unsigned short>>;
#else
// Stand-in cache type used when USE_FBGEMM is not defined. It stores nothing;
// it exists so callers can create and pass a kernel cache without guarding
// their own code with #ifdef USE_FBGEMM.
struct _EmbeddingBagKernelCache {
  // Default construction mirrors the FBGEMM-backed cache type, which is also
  // default-constructible; without it, `_EmbeddingBagKernelCache cache;`
  // would only compile in FBGEMM builds.
  _EmbeddingBagKernelCache() = default;

  // Accepts the optional block size for interface parity and ignores it.
  explicit _EmbeddingBagKernelCache(c10::optional<int64_t> /* maybe_block_size */) {}
};
#endif
// Declaration only; the definition is not visible in this header.
// output, offset2bag and bag_size are mutable references and max_indices is
// a mutable pointer; weight, indices and offsets are read-only inputs.
// NOTE(review): max_indices is passed by pointer, unlike the other outputs —
// presumably it may be null; confirm against the definition.
//
// Defaults when omitted by the caller:
//   mode                = 0
//   per_sample_weights  = c10::nullopt
//   include_last_offset = false
//   padding_idx         = -1
//   fbgemm_kernel_cache = nullptr
// NOTE(review): a null fbgemm_kernel_cache presumably means "no cached
// kernels available" — confirm against the definition.
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
Tensor& bag_size, Tensor* max_indices,
const Tensor &weight, const Tensor &indices,
const Tensor &offsets, const int64_t mode = 0,
const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
bool include_last_offset = false,
int64_t padding_idx = -1,
_EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
// Declaration only; the definition is not visible in this header.
// output, offset2bag and bag_size are mutable references and p_max_indices
// is a mutable pointer; weight, indices and offsets are read-only inputs.
// Unlike _embedding_bag_cpu_impl_out, this overload has no defaults for
// mode, per_sample_weights or include_last_offset, takes padding_idx as an
// optional value, and additionally accepts scale_grad_by_freq and sparse.
// fbgemm_kernel_cache defaults to nullptr when omitted.
// NOTE(review): how scale_grad_by_freq and sparse are used, and how an
// absent padding_idx is interpreted, must be confirmed against the definition.
void _embedding_bag_cpu_out(
at::Tensor& output,
at::Tensor& offset2bag,
at::Tensor& bag_size,
at::Tensor* p_max_indices,
const at::Tensor& weight,
const at::Tensor& indices,
const at::Tensor& offsets,
const bool scale_grad_by_freq,
const int64_t mode,
const bool sparse,
const c10::optional<at::Tensor>& per_sample_weights,
const bool include_last_offset,
const c10::optional<int64_t>& padding_idx,
_EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
} // native
} // at
|