File size: 13,008 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 |
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Array.h>
#include <atomic>
#include <utility>
#include <variant>
// Implements instruction set specific function dispatch.
//
// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
// compiled multiple times with different compiler flags (e.g. -mavx2). A
// DispatchStub contains a table of function pointers for a kernel. At runtime,
// the fastest available kernel is chosen based on the features reported by
// cpuinfo.
//
// Example:
//
// In native/MyKernel.h:
// using fn_type = void(*)(const Tensor& x);
// DECLARE_DISPATCH(fn_type, stub);
//
// In native/MyKernel.cpp
// DEFINE_DISPATCH(stub);
//
// In native/cpu/MyKernel.cpp:
// namespace {
// // use anonymous namespace so that different cpu versions won't conflict
// void kernel(const Tensor& x) { ... }
// }
// REGISTER_DISPATCH(stub, &kernel);
//
// To call:
// stub(kCPU, tensor);
//
// TODO: CPU instruction set selection should be folded into whatever
// the main dispatch mechanism is.
//
// Supported device types for registration:
// - CPU: Central Processing Unit
// - CUDA: NVIDIA GPUs
// - HIP: AMD GPUs
// - MPS: Apple Silicon GPUs (Metal Performance Shaders)
// - PrivateUse1: Reserved for private/custom device types
//
// If you want to update the list of supported devices, add a new dispatch_ptr
// member to the DispatchStubImpl struct and update the get_call_ptr switch.
// You will also need to update the inlined list in `is_device_supported`.
//
//
// Ignore warnings about DispatchStub::DEFAULT, AVX512, AVX2, etc. being defined elsewhere.
C10_CLANG_DIAGNOSTIC_PUSH()
C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
namespace at::native {
// CPU instruction-set levels a DispatchStub can choose between at runtime.
// DEFAULT always exists; the remaining enumerators depend on which target
// family the build is configured for. NUM_OPTIONS follows the last real
// capability, so its value equals the number of capabilities in this build.
enum class CPUCapability : int {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  NUM_OPTIONS
};
// Reasons a non-throwing lookup (try_get_call_ptr / try_choose_cpu_impl)
// can report instead of a kernel pointer.
enum class ErrorType : int {
  MissingDeviceKernel = 0,
  DeviceNotSupported = 1,
};
// Outcome of a non-throwing lookup: either the type-erased kernel pointer or
// the ErrorType explaining why none could be produced.
using DispatchResult = std::variant<void*, ErrorType>;
// Returns the CPU capability level to dispatch on; defined out of line.
auto get_cpu_capability() -> CPUCapability;
// Primary template, declared only. The function-pointer partial
// specialization below is the one that is defined; T is the per-stub struct
// type generated by DECLARE_DISPATCH.
template <class FnPtr, class T>
struct DispatchStub;
/**
 * Non-template state and lookup logic shared by every DispatchStub<>.
 *
 * The sole purpose of this class is to outline methods that don't need to be
 * specialized or otherwise inlined and duplicated (by the compiler due to
 * template expansion), since that causes size bloat if there are a significant
 * number of specializations of the DispatchStub<> class.
 *
 * Every lookup method receives the per-capability CPU kernels as type-erased
 * void* arguments. Which of those parameters exist depends on the
 * HAVE_*_CPU_DEFINITION macros, so callers must forward exactly the same set
 * (DispatchStub<> does this with matching #ifdef blocks).
 * NOTE(review): the method bodies are defined out of line, not in this header.
 */
struct TORCH_API DispatchStubImpl {
// Returns the call pointer for `device_type`, or an ErrorType if no pointer
// is available. The main difference from get_call_ptr() is that this method
// reports failure through the returned variant instead of raising an
// exception.
DispatchResult try_get_call_ptr(
c10::DeviceType device_type
, void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
);
// Non-throwing counterpart of choose_cpu_impl(), taking the same candidate
// list: it returns an ErrorType rather than raising an exception when no CPU
// kernel can be chosen.
DispatchResult try_choose_cpu_impl(
void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
);
// Returns the call pointer for `device_type`. Unlike try_get_call_ptr(),
// failure is reported by raising an exception.
void* get_call_ptr(
c10::DeviceType device_type
, void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
);
/**
 * Picks the CPU kernel, in decreasing order of preference, from the
 * per-capability candidates. DispatchStubImpl::get_call_ptr() falls back to
 * this when cpu_dispatch_ptr does not already hold a kernel.
 */
void* choose_cpu_impl(
void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
);
// Kernel slots.
// cpu_dispatch_ptr is atomic and holds the selected CPU kernel that
// get_call_ptr() consults first. The device pointers are written by
// DispatchStub<>::set_*_dispatch_ptr().
//
// Fixing dispatch error in Windows debug builds.
// See https://github.com/pytorch/pytorch/issues/22681 for more details.
// Under MSVC debug builds the members are declared without default member
// initializers; everywhere else they start out as nullptr.
// NOTE(review): this presumably relies on stubs being namespace-scope objects
// (see DEFINE_DISPATCH), which are zero-initialized before any dynamic
// initialization runs. Confirm against the linked issue.
#if defined(_MSC_VER) && defined(_DEBUG)
std::atomic<void*> cpu_dispatch_ptr;
void* cuda_dispatch_ptr;
void* hip_dispatch_ptr;
void* mps_dispatch_ptr;
void* privateuse1_dispatch_ptr;
#else
std::atomic<void*> cpu_dispatch_ptr{nullptr};
void* cuda_dispatch_ptr = nullptr;
void* hip_dispatch_ptr = nullptr;
void* mps_dispatch_ptr = nullptr;
void* privateuse1_dispatch_ptr = nullptr;
#endif
};
/**
 * Typed dispatch table for one kernel signature.
 *
 * `rT (*)(Args...)` is the kernel's function-pointer type. `T` is the unique
 * struct generated by DECLARE_DISPATCH for each stub, so every stub gets its
 * own copy of the static per-capability members below, even when two stubs
 * share a kernel signature.
 *
 * CPU kernels are held in the static members DEFAULT / AVX512 / AVX2 / VSX /
 * ZVECTOR, which REGISTER_ARCH_DISPATCH explicitly specializes. Device kernels
 * are stored in the embedded DispatchStubImpl through the set_*_dispatch_ptr()
 * setters. Kernel selection itself is delegated to DispatchStubImpl, so that
 * logic is not duplicated for every specialization.
 */
template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T> {
  using FnPtr = rT (*) (Args...);

  DispatchStub() = default;
  DispatchStub(const DispatchStub&) = delete;
  DispatchStub& operator=(const DispatchStub&) = delete;

 private:
  // Resolves the kernel for `device_type`. Every compiled-in CPU variant is
  // passed, type-erased, to DispatchStubImpl::get_call_ptr(), which performs
  // the selection. The result is cast back to this stub's FnPtr.
  FnPtr get_call_ptr(const c10::DeviceType device_type) {
    return reinterpret_cast<FnPtr>(
      impl.get_call_ptr(device_type
        , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
        , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
        , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
        , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
        , reinterpret_cast<void*>(ZVECTOR)
#endif
      )
    );
  }

 public:
  // Looks up the kernel registered for `device_type` and calls it, perfectly
  // forwarding `args`.
  template <typename... ArgTypes>
  rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
    FnPtr call_ptr = get_call_ptr(device_type);
    return (*call_ptr)(std::forward<ArgTypes>(args)...);
  }

  // Device-slot setters, used by the Register*Dispatch helpers.
  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
    impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
  void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
    impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  // Returns true if the dispatcher has a kernel registered for this device
  // type. It uses the non-throwing DispatchStubImpl::try_get_call_ptr(), so
  // an unsupported device yields false instead of an exception.
  bool is_device_supported(const c10::DeviceType device_type) {
    const auto result = impl.try_get_call_ptr(device_type
      , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , reinterpret_cast<void*>(ZVECTOR)
#endif
    );
    return !std::holds_alternative<ErrorType>(result);
  }

  // Per-capability CPU kernels. Each one is explicitly specialized (and
  // thereby defined) by REGISTER_ARCH_DISPATCH or the macros built on it.
  // A member exists only when its HAVE_*_CPU_DEFINITION macro is set.
  static TORCH_API FnPtr DEFAULT;
#ifdef HAVE_AVX512_CPU_DEFINITION
  static TORCH_API FnPtr AVX512;
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
  static TORCH_API FnPtr AVX2;
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
  static TORCH_API FnPtr VSX;
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
  static TORCH_API FnPtr ZVECTOR;
#endif

 private:
  DispatchStubImpl impl;
};
namespace {
// Static-registration helpers. Constructing one of these objects stores
// `kernel` in the corresponding device slot of `target`. The
// REGISTER_*_DISPATCH macros create them as `static` objects, so registration
// happens when those objects are initialized.
template <typename Stub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(Stub& target, typename Stub::FnPtr kernel) {
    target.set_cuda_dispatch_ptr(kernel);
  }
};

template <typename Stub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(Stub& target, typename Stub::FnPtr kernel) {
    target.set_mps_dispatch_ptr(kernel);
  }
};

// HIP kernels are currently stored in the CUDA slot, not the HIP slot.
template <typename Stub>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(Stub& target, typename Stub::FnPtr kernel) {
    // TODO: make this point at hip_dispatch_ptr
    target.set_cuda_dispatch_ptr(kernel);
  }
};

template <typename Stub>
struct RegisterPRIVATEUSE1Dispatch {
  RegisterPRIVATEUSE1Dispatch(Stub& target, typename Stub::FnPtr kernel) {
    target.set_privateuse1_dispatch_ptr(kernel);
  }
};
} // anonymous namespace
// The compiler will complain if you put things like std::tuple<Tensor, Tensor>
// in the `fn` argument of DECLARE_DISPATCH (the comma splits the macro
// argument). Some possible workarounds, e.g. adding parentheses and using a
// helper struct to get rid of them, do not work with MSVC. So use a
// `using`-declaration if you need to pass in such an `fn`, e.g.
// grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
//
// DECLARE_DISPATCH(fn, name) is for headers. It defines a unique, non-copyable
// struct type `name##_DECLARE_DISPATCH_type` deriving from
// DispatchStub<fn, name##_DECLARE_DISPATCH_type>, and declares an exported
// extern object `name` of that type.
#define DECLARE_DISPATCH(fn, name) \
struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> { \
name##_DECLARE_DISPATCH_type() = default; \
name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete; \
name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
}; \
extern TORCH_API struct name##_DECLARE_DISPATCH_type name;
// DEFINE_DISPATCH(name) provides the definition of the extern stub object
// declared by DECLARE_DISPATCH. Use it in exactly one translation unit.
#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name
// REGISTER_ARCH_DISPATCH(name, arch, fn) explicitly specializes the static
// member DispatchStub<...>::arch (one of DEFAULT, AVX512, AVX2, VSX, ZVECTOR)
// belonging to stub `name`, and initializes it to `fn`.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub<name##_DECLARE_DISPATCH_type::FnPtr, struct name##_DECLARE_DISPATCH_type>::arch = fn;
// Per-ISA helpers. Each expands to REGISTER_ARCH_DISPATCH only when the
// matching HAVE_*_CPU_DEFINITION macro is set, i.e. only when the
// corresponding static member of DispatchStub<> exists. Otherwise it expands
// to nothing.
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif
// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
REGISTER_AVX512_DISPATCH(name, fn) \
REGISTER_AVX2_DISPATCH(name, fn) \
REGISTER_VSX_DISPATCH(name, fn) \
REGISTER_ZVECTOR_DISPATCH(name, fn)
// Registers nullptr for every CPU capability, for stubs with no CPU kernel.
#define REGISTER_NO_CPU_DISPATCH(name) \
REGISTER_ALL_CPU_DISPATCH(name, nullptr)
// Per-device registration. Each macro defines a `static` Register*Dispatch
// object named `name##__register` whose constructor stores `fn` in the stub's
// device slot. Because the object's name is derived only from `name`, a
// given stub can be registered at most once per scope with these macros.
// REGISTER_HIP_DISPATCH currently writes the CUDA slot (see
// RegisterHIPDispatch).
// NOTE(review): identifiers containing "__" are reserved in C++. Consider
// renaming the generated object if this is ever revisited.
#define REGISTER_CUDA_DISPATCH(name, fn) \
static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_HIP_DISPATCH(name, fn) \
static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_MPS_DISPATCH(name, fn) \
static RegisterMPSDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
static RegisterPRIVATEUSE1Dispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
// REGISTER_DISPATCH(name, fn) expands differently depending on how the
// current translation unit is compiled:
//   - __CUDACC__ or __HIPCC__: registers into the CUDA slot.
//   - Objective-C++ with USE_MPS: registers into the MPS slot.
//   - CPU_CAPABILITY defined: registers into the static member named by
//     CPU_CAPABILITY (which must therefore expand to a capability name such
//     as AVX2).
// Otherwise this header leaves REGISTER_DISPATCH undefined.
// NB: For CUDA kernels this macro must be used in an actual 'cu' file; if you
// try using it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// In AVX512 translation units (CPU_CAPABILITY_AVX512), REGISTER_DISPATCH
// registers nullptr for the current capability; `fn` still appears in the
// expression but its value is discarded. NOTE(review): presumably this keeps
// `fn` referenced so it is not reported as unused. Confirm.
// Use ALSO_REGISTER_AVX512_DISPATCH when a kernel should really be registered
// for AVX512. It registers `fn` for the current capability unconditionally.
#ifdef CPU_CAPABILITY_AVX512
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
} // namespace at::native
C10_CLANG_DIAGNOSTIC_POP()
|