|
#pragma once |
|
|
|
#include <c10/core/Backend.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>

#include <atomic>
#include <type_traits>
#include <utility>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// DispatchStub (below) declares its per-ISA static data members (DEFAULT,
// AVX2, AVX512, ...) in this header. They are only ever defined through
// explicit specializations emitted by REGISTER_ARCH_DISPATCH in kernel
// translation units.
// NOTE(review): presumably this is what clang's -Wundefined-var-template would
// otherwise flag wherever a stub is used — confirm.
// The matching pop is at the end of this header.
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-var-template"
#endif
|
|
|
namespace at { namespace native { |
|
|
|
// CPU instruction-set levels that a kernel variant can be registered for.
// The enumerator set depends on the build:
//   * HAVE_VSX_CPU_DEFINITION builds:     DEFAULT, VSX
//   * HAVE_ZVECTOR_CPU_DEFINITION builds: DEFAULT, ZVECTOR
//   * all other builds:                   DEFAULT, AVX2, AVX512
// NUM_OPTIONS is implicitly one past the last level, i.e. the number of levels
// in the current build.
enum class CPUCapability {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  NUM_OPTIONS
};
|
|
|
// Returns the CPUCapability used when selecting a CPU kernel variant.
// NOTE(review): defined outside this header; presumably reflects the host CPU's
// supported instruction sets (possibly with a user override) — confirm.
CPUCapability get_cpu_capability();
|
|
|
// Primary template, intentionally declared but never defined. Only the partial
// specialization for function-pointer types `rT (*)(Args...)` below is usable.
// Instantiating a stub with any other FnPtr fails with an incomplete type.
template <typename FnPtr, typename T>
struct DispatchStub;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Non-template core shared by every DispatchStub instantiation. Kernel pointers
// are kept type-erased as void*, so the kernel-selection functions below
// (defined outside this header) do not have to be templates.
struct TORCH_API DispatchStubImpl {
  // Returns, as void*, the kernel to run for `device_type`.
  // The CPU variants are supplied by the caller because they are static members
  // of the typed DispatchStub<FnPtr, T> template (one set per stub), which this
  // non-template struct cannot name. Only the ISA parameters whose
  // HAVE_*_CPU_DEFINITION macro is set exist, mirroring DispatchStub.
  // NOTE(review): the definition is not visible here. Presumably it resolves
  // the CPU case via choose_cpu_impl and returns the *_dispatch_ptr members for
  // device types — confirm in the .cpp.
  void* get_call_ptr(
    DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
      , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , void *ZVECTOR
#endif
  );

  // Picks one of the supplied CPU kernel variants.
  // NOTE(review): the definition is not visible here. Presumably it chooses by
  // get_cpu_capability() and decides how null (unregistered) variants are
  // handled — confirm.
  void* choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Per-device kernel slots.
  // The device slots are plain pointers written through
  // DispatchStub::set_{cuda,hip,mps}_dispatch_ptr, which the Register*Dispatch
  // helpers call. cpu_dispatch_ptr is atomic.
  // NOTE(review): presumably cpu_dispatch_ptr caches choose_cpu_impl's result
  // and may be filled concurrently on first use — confirm.
  //
  // MSVC debug builds deliberately omit the default member initializers.
  // NOTE(review): presumably this keeps global stubs statically
  // zero-initialized instead of dynamically initialized. A dynamic initializer
  // could run after, and wipe out, registrations made by other translation
  // units' static initializers (see pytorch/pytorch#22681) — confirm.
  // For stubs with static storage duration, as DEFINE_DISPATCH creates, the
  // slots still start out null because such objects are zero-initialized.
#if defined(_MSC_VER) && defined(_DEBUG)
  std::atomic<void*> cpu_dispatch_ptr;
  void* cuda_dispatch_ptr;
  void* hip_dispatch_ptr;
  void* mps_dispatch_ptr;
#else
  std::atomic<void*> cpu_dispatch_ptr{nullptr};
  void* cuda_dispatch_ptr = nullptr;
  void* hip_dispatch_ptr = nullptr;
  void* mps_dispatch_ptr = nullptr;
#endif
};
|
|
|
template <typename rT, typename T, typename... Args> |
|
struct DispatchStub<rT (*)(Args...), T> { |
|
using FnPtr = rT (*) (Args...); |
|
|
|
DispatchStub() = default; |
|
DispatchStub(const DispatchStub&) = delete; |
|
DispatchStub& operator=(const DispatchStub&) = delete; |
|
|
|
private: |
|
FnPtr get_call_ptr(DeviceType device_type) { |
|
return reinterpret_cast<FnPtr>( |
|
impl.get_call_ptr(device_type |
|
, reinterpret_cast<void*>(DEFAULT) |
|
#ifdef HAVE_AVX512_CPU_DEFINITION |
|
, reinterpret_cast<void*>(AVX512) |
|
#endif |
|
#ifdef HAVE_AVX2_CPU_DEFINITION |
|
, reinterpret_cast<void*>(AVX2) |
|
#endif |
|
#ifdef HAVE_VSX_CPU_DEFINITION |
|
, reinterpret_cast<void*>(VSX) |
|
#endif |
|
#ifdef HAVE_ZVECTOR_CPU_DEFINITION |
|
, reinterpret_cast<void*>(ZVECTOR) |
|
#endif |
|
) |
|
); |
|
} |
|
|
|
public: |
|
template <typename... ArgTypes> |
|
rT operator()(DeviceType device_type, ArgTypes&&... args) { |
|
FnPtr call_ptr = get_call_ptr(device_type); |
|
return (*call_ptr)(std::forward<ArgTypes>(args)...); |
|
} |
|
|
|
void set_cuda_dispatch_ptr(FnPtr fn_ptr) { |
|
impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr); |
|
} |
|
|
|
void set_hip_dispatch_ptr(FnPtr fn_ptr) { |
|
impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr); |
|
} |
|
|
|
void set_mps_dispatch_ptr(FnPtr fn_ptr) { |
|
impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr); |
|
} |
|
|
|
static TORCH_API FnPtr DEFAULT; |
|
#ifdef HAVE_AVX512_CPU_DEFINITION |
|
static TORCH_API FnPtr AVX512; |
|
#endif |
|
#ifdef HAVE_AVX2_CPU_DEFINITION |
|
static TORCH_API FnPtr AVX2; |
|
#endif |
|
#ifdef HAVE_VSX_CPU_DEFINITION |
|
static TORCH_API FnPtr VSX; |
|
#endif |
|
#ifdef HAVE_ZVECTOR_CPU_DEFINITION |
|
static TORCH_API FnPtr ZVECTOR; |
|
#endif |
|
private: |
|
DispatchStubImpl impl; |
|
}; |
|
|
|
namespace {

// Static-initialization helpers behind REGISTER_{CUDA,HIP,MPS}_DISPATCH.
// Constructing one of these objects stores `value` into a device slot of
// `stub`; the helper object itself holds no state.
// `Stub` is a concrete stub type, i.e. something exposing `FnPtr` and the
// set_*_dispatch_ptr setters.

template <typename Stub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(Stub &stub, typename Stub::FnPtr value) {
    stub.set_cuda_dispatch_ptr(value);
  }
};

template <typename Stub>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(Stub &stub, typename Stub::FnPtr value) {
    // Deliberately the CUDA slot, not set_hip_dispatch_ptr. This matches
    // REGISTER_DISPATCH, which also routes __HIPCC__ builds through
    // REGISTER_CUDA_DISPATCH.
    // NOTE(review): presumably HIP kernels are looked up under the CUDA device
    // type in HIP builds — confirm before switching to the HIP slot.
    stub.set_cuda_dispatch_ptr(value);
  }
};

template <typename Stub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(Stub &stub, typename Stub::FnPtr value) {
    stub.set_mps_dispatch_ptr(value);
  }
};

}  // namespace
|
|
|
|
|
|
|
|
|
|
|
// DECLARE_DISPATCH(fn, name): declares a dispatch stub called `name` for
// kernels of function-pointer type `fn`. It introduces:
//   * a non-copyable struct type `name` deriving from DispatchStub<fn, name>;
//   * an `extern TORCH_API` object of that type, also called `name`.
// The object hides the type, so later code must refer to the type as
// `struct name`, as REGISTER_ARCH_DISPATCH and REGISTER_*_DISPATCH do.
// The object must be defined by DEFINE_DISPATCH(name) in one translation unit.
// Use as a statement: the expansion omits the trailing semicolon.
#define DECLARE_DISPATCH(fn, name) \
  struct name : DispatchStub<fn, name> { \
    name() = default; \
    name(const name&) = delete; \
    name& operator=(const name&) = delete; \
  }; \
  extern TORCH_API struct name name
|
|
|
#define DEFINE_DISPATCH(name) struct name name |
|
|
|
// REGISTER_ARCH_DISPATCH(name, arch, fn): binds CPU kernel `fn` to stub `name`
// for ISA slot `arch`.
// `arch` must be one of DispatchStub's static member names: DEFAULT, AVX512,
// AVX2, VSX or ZVECTOR, as enabled in the build.
// It works by explicitly specializing (and thereby defining) that static data
// member, so each (name, arch) pair can be registered only once per program.
// The expansion already ends with ';', which is why REGISTER_ALL_CPU_DISPATCH
// can chain these without separators.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name::FnPtr TORCH_API DispatchStub<name::FnPtr, struct name>::arch = fn;
|
|
|
// Per-ISA shorthands for REGISTER_ARCH_DISPATCH. Each shorthand:
//   * expands to nothing when its HAVE_<ISA>_CPU_DEFINITION macro is not set;
//   * forwards to REGISTER_ARCH_DISPATCH when it is.
// Callers can therefore use them unconditionally, without their own #ifdefs.

#ifndef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#endif

#ifndef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#endif

#ifndef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#endif

#ifndef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#endif
|
|
|
|
|
|
|
// REGISTER_ALL_CPU_DISPATCH(name, fn): registers the same `fn` for the DEFAULT
// slot and for every ISA slot compiled into this build. Disabled ISAs expand to
// nothing via their REGISTER_<ISA>_DISPATCH shorthand. No separators are needed
// between the pieces because each REGISTER_ARCH_DISPATCH expansion ends with
// ';'.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
  REGISTER_AVX512_DISPATCH(name, fn) \
  REGISTER_AVX2_DISPATCH(name, fn) \
  REGISTER_VSX_DISPATCH(name, fn) \
  REGISTER_ZVECTOR_DISPATCH(name, fn)
|
|
|
// REGISTER_NO_CPU_DISPATCH(name): stores nullptr in every CPU slot of stub
// `name` (DEFAULT plus each ISA compiled into this build). This marks the stub
// as having no CPU kernel while still providing definitions for its static
// per-ISA members.
#define REGISTER_NO_CPU_DISPATCH(name) REGISTER_ALL_CPU_DISPATCH(name, nullptr)
|
|
|
// Device registration macros. Each defines a file-local (`static`) helper
// object. The helper's constructor stores `fn` into the matching slot of stub
// `name` while the translation unit's static objects are initialized.
// Note that REGISTER_HIP_DISPATCH also ends up in the CUDA slot (see
// RegisterHIPDispatch).
// All three name the helper `<name>__register`, so a given stub can be
// registered by at most one of them per translation unit.
// NOTE(review): identifiers containing a double underscore are reserved for the
// implementation in C++. `name ## _register` would avoid that; the rename is
// left for a separate change.
#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name> name ## __register(name, fn);

#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name> name ## __register(name, fn);

#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name> name ## __register(name, fn);
|
|
|
|
|
|
|
// REGISTER_DISPATCH(name, fn) picks the registration flavour from how the
// current translation unit is being compiled:
//   * __CUDACC__                   -> CUDA slot
//   * __HIPCC__                    -> also the CUDA slot
//   * __OBJC__ together with USE_MPS -> MPS slot
//   * CPU_CAPABILITY defined       -> the CPU slot named by CPU_CAPABILITY
// In any other translation unit REGISTER_DISPATCH is left undefined.
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// HIP kernels go through the CUDA registration path here, consistent with
// RegisterHIPDispatch writing the CUDA slot. REGISTER_HIP_DISPATCH is not used.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// Only taken when the translation unit is compiled as Objective-C(++) with
// USE_MPS defined.
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// CPU_CAPABILITY is pasted as the `arch` argument, so it must expand to one of
// DispatchStub's static member names (DEFAULT, AVX512, AVX2, VSX, ZVECTOR).
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
// Explicitly stores nullptr in the AVX512 slot; it expands to nothing when
// AVX512 is not built.
// NOTE(review): how a null AVX512 slot is handled at dispatch time is decided
// by choose_cpu_impl, which is not visible here. Presumably it falls back to
// another ISA — confirm.
#define REGISTER_NO_AVX512_DISPATCH(name) \
  REGISTER_AVX512_DISPATCH(name, nullptr)
#endif
|
|
|
|
|
}} |
|
|
|
|
|
// Restore the clang diagnostic state saved by the push near the top of this
// header. This re-enables -Wundefined-var-template for code that includes it.
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
|
|