|
#pragma once |
|
|
|
#include <ATen/Config.h> |
|
#include <ATen/Parallel.h> |
|
#include <ATen/OpMathType.h> |
|
#include <ATen/cpu/vec/functional.h> |
|
#include <ATen/cpu/vec/vec.h> |
|
#include <c10/util/complex.h> |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <algorithm> |
|
#include <cstddef> |
|
#include <cstdint> |
|
#include <cstring> |
|
#include <type_traits> |
|
|
|
#if AT_MKL_ENABLED() && !defined(__APPLE__) |
|
#include <mkl.h> |
|
#endif |
|
|
|
namespace at { |
|
namespace vml { |
|
inline namespace CPU_CAPABILITY { |
|
|
|
using namespace vec; |
|
|
|
template <typename scalar_t> |
|
inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) { |
|
parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { |
|
map( |
|
[](const Vectorized<scalar_t>& x) { |
|
return Vectorized<scalar_t>((scalar_t)(1)) / x.sqrt(); |
|
}, |
|
out + begin, |
|
in + begin, |
|
end - begin); |
|
}); |
|
} |
|
|
|
|
|
|
|
#define IMPLEMENT_VML(op) \ |
|
template <typename scalar_t> \ |
|
inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \ |
|
using vec_t = Vectorized<vec_scalar_t<scalar_t>>; \ |
|
vec::map([](vec_t x) { return x.op(); }, out, in, size); \ |
|
} \ |
|
|
|
IMPLEMENT_VML(abs) |
|
IMPLEMENT_VML(acos) |
|
IMPLEMENT_VML(asin) |
|
IMPLEMENT_VML(atan) |
|
IMPLEMENT_VML(atanh) |
|
IMPLEMENT_VML(ceil) |
|
IMPLEMENT_VML(cos) |
|
|
|
IMPLEMENT_VML(erf) |
|
IMPLEMENT_VML(erfc) |
|
IMPLEMENT_VML(erfinv) |
|
IMPLEMENT_VML(exp) |
|
IMPLEMENT_VML(expm1) |
|
IMPLEMENT_VML(floor) |
|
IMPLEMENT_VML(i0) |
|
IMPLEMENT_VML(i0e) |
|
IMPLEMENT_VML(digamma) |
|
IMPLEMENT_VML(reciprocal) |
|
IMPLEMENT_VML(log) |
|
IMPLEMENT_VML(log10) |
|
IMPLEMENT_VML(log1p) |
|
IMPLEMENT_VML(log2) |
|
IMPLEMENT_VML(neg) |
|
IMPLEMENT_VML(sin) |
|
|
|
IMPLEMENT_VML(sqrt) |
|
IMPLEMENT_VML(round) |
|
IMPLEMENT_VML(rsqrt) |
|
IMPLEMENT_VML(tan) |
|
IMPLEMENT_VML(tanh) |
|
IMPLEMENT_VML(trunc) |
|
IMPLEMENT_VML(lgamma) |
|
|
|
|
|
#if AT_MKL_ENABLED() && !defined(__APPLE__) |
|
|
|
|
|
|
|
|
|
static_assert( |
|
std::is_same_v<MKL_INT, int32_t> || std::is_same_v<MKL_INT, int64_t>, |
|
"MKL_INT is assumed to be int32_t or int64_t"); |
|
#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype) \ |
|
template <> \ |
|
inline void v##op(type * out, const type * in, int64_t size) { \ |
|
int64_t max_mkl_ind = std::numeric_limits<MKL_INT>::max(); \ |
|
if (size <= static_cast<int64_t>(max_mkl_ind)) { \ |
|
vm##mkltype##mklop( \ |
|
size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ |
|
} else { \ |
|
MKL_INT ind = 0; \ |
|
int64_t chunks = size / max_mkl_ind; \ |
|
int64_t rest = size % max_mkl_ind; \ |
|
for (; ind < chunks; ind++) { \ |
|
vm##mkltype##mklop( \ |
|
max_mkl_ind, \ |
|
in + ind * max_mkl_ind, \ |
|
out + ind * max_mkl_ind, \ |
|
VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ |
|
} \ |
|
vm##mkltype##mklop( \ |
|
rest, \ |
|
in + ind * max_mkl_ind, \ |
|
out + ind * max_mkl_ind, \ |
|
VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ |
|
} \ |
|
} |
|
|
|
#define IMPLEMENT_VML_MKL(op, mklop) \ |
|
IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \ |
|
IMPLEMENT_VML_MKL_STUB(op, mklop, double, d) |
|
|
|
|
|
|
|
IMPLEMENT_VML_MKL(acos, Acos) |
|
IMPLEMENT_VML_MKL(asin, Asin) |
|
IMPLEMENT_VML_MKL(atan, Atan) |
|
IMPLEMENT_VML_MKL(cos, Cos) |
|
|
|
IMPLEMENT_VML_MKL(erf, Erf) |
|
IMPLEMENT_VML_MKL(erfc, Erfc) |
|
IMPLEMENT_VML_MKL(erfinv, ErfInv) |
|
IMPLEMENT_VML_MKL(exp, Exp) |
|
|
|
IMPLEMENT_VML_MKL(log, Ln) |
|
IMPLEMENT_VML_MKL(log10, Log10) |
|
IMPLEMENT_VML_MKL(sin, Sin) |
|
|
|
IMPLEMENT_VML_MKL(sqrt, Sqrt) |
|
IMPLEMENT_VML_MKL(tan, Tan) |
|
IMPLEMENT_VML_MKL(tanh, Tanh) |
|
IMPLEMENT_VML_MKL(trunc, Trunc) |
|
|
|
|
|
|
|
|
|
|
|
#if INTEL_MKL_VERSION >= 20180406 |
|
IMPLEMENT_VML_MKL(log2, Log2) |
|
#endif |
|
|
|
#endif |
|
|
|
} |
|
} |
|
} |
|
|