|
#pragma once |
|
|
|
#include <ATen/BlasBackend.h> |
|
#include <ATen/CPUGeneratorImpl.h> |
|
#include <ATen/DeviceAccelerator.h> |
|
#include <ATen/LinalgBackend.h> |
|
#include <ATen/core/ATenGeneral.h> |
|
#include <ATen/core/DeprecatedTypeProperties.h> |
|
#include <ATen/core/Generator.h> |
|
#include <ATen/core/LegacyTypeDispatch.h> |
|
#include <ATen/detail/AcceleratorHooksInterface.h> |
|
#include <ATen/detail/CUDAHooksInterface.h> |
|
#include <ATen/detail/HIPHooksInterface.h> |
|
#include <ATen/detail/IPUHooksInterface.h> |
|
#include <ATen/detail/MAIAHooksInterface.h> |
|
#include <ATen/detail/MPSHooksInterface.h> |
|
#include <ATen/detail/MTIAHooksInterface.h> |
|
#include <ATen/detail/PrivateUse1HooksInterface.h> |
|
#include <ATen/detail/XPUHooksInterface.h> |
|
#include <c10/core/QEngine.h> |
|
#include <c10/core/impl/DeviceGuardImplInterface.h> |
|
#include <c10/util/CallOnce.h> |
|
#include <c10/util/Exception.h> |
|
#include <c10/util/env.h> |
|
#include <c10/util/irange.h> |
|
|
|
#include <cstdint> |
|
#include <mutex> |
|
|
|
namespace at { |
|
|
|
class Tensor; |
|
|
|
enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM }; |
|
|
|
class TORCH_API Context { |
|
public: |
|
Context(); |
|
|
|
const Generator& defaultGenerator(Device device) { |
|
c10::DeviceType device_type = device.type(); |
|
initCUDAIfNeeded(device_type); |
|
initHIPIfNeeded(device_type); |
|
if (device_type == at::kCPU) { |
|
return at::detail::getDefaultCPUGenerator(); |
|
} else if (device_type == at::kCUDA) { |
|
return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index()); |
|
} else if (device_type == at::kMPS) { |
|
return at::detail::getMPSHooks().getDefaultMPSGenerator(); |
|
} else if (device_type == at::kXPU) { |
|
return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index()); |
|
} else if (device_type == at::kIPU) { |
|
return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index()); |
|
} else if (device_type == at::kPrivateUse1) { |
|
return at::GetPrivateUse1HooksInterface()->getDefaultGenerator( |
|
device.index()); |
|
} else { |
|
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled."); |
|
} |
|
} |
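  // Usage sketch (hypothetical call site, not part of this header): seed the
  // default CPU generator while holding its mutex.
  //   auto gen = at::globalContext().defaultGenerator(at::kCPU);
  //   std::lock_guard<std::mutex> lock(gen.mutex());
  //   gen.set_current_seed(42);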
|
const AcceleratorHooksInterface& getAcceleratorHooksInterface( |
|
std::optional<c10::DeviceType> opt_device_type = c10::nullopt) { |
|
c10::DeviceType device_type = opt_device_type.has_value() |
|
? opt_device_type.value() |
|
: at::getAccelerator(true).value(); |
|
if (device_type == at::kCUDA) { |
|
return at::detail::getCUDAHooks(); |
|
} else if (device_type == at::kMPS) { |
|
return at::detail::getMPSHooks(); |
|
} else if (device_type == at::kPrivateUse1) { |
|
return at::detail::getPrivateUse1Hooks(); |
|
} else if (device_type == at::kMTIA) { |
|
return at::detail::getMTIAHooks(); |
|
} else { |
|
AT_ERROR( |
|
c10::DeviceTypeName(device_type), " device type not an accelerator."); |
|
} |
|
} |
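  // For example (a sketch; assumes an accelerator backend is present, since
  // the no-argument form asserts via at::getAccelerator(true)):
  //   const auto& hooks = at::globalContext().getAcceleratorHooksInterface();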
|
Device getDeviceFromPtr(void* data, c10::DeviceType device_type) { |
|
initCUDAIfNeeded(device_type); |
|
initHIPIfNeeded(device_type); |
|
initXPUIfNeeded(device_type); |
|
if (device_type == at::kCPU) { |
|
return c10::DeviceType::CPU; |
|
} else if (device_type == at::kCUDA) { |
|
return at::detail::getCUDAHooks().getDeviceFromPtr(data); |
|
} else if (device_type == at::kXPU) { |
|
return at::detail::getXPUHooks().getDeviceFromPtr(data); |
|
} else if (device_type == at::kPrivateUse1) { |
|
return at::GetPrivateUse1HooksInterface()->getDeviceFromPtr(data); |
|
} else { |
|
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled."); |
|
} |
|
} |
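  // Returns true if `data` points into CUDA pinned (page-locked) host
  // memory; this simply delegates to the CUDA hooks.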
|
static bool isPinnedPtr(const void* data) { |
|
return detail::getCUDAHooks().isPinnedPtr(data); |
|
} |
|
static bool hasOpenMP(); |
|
static bool hasMKL(); |
|
static bool hasLAPACK(); |
|
static bool hasMKLDNN(); |
|
static bool hasMAGMA() { |
|
return detail::getCUDAHooks().hasMAGMA(); |
|
} |
|
static bool hasCUDA() { |
|
return detail::getCUDAHooks().hasCUDA(); |
|
} |
|
static bool hasMTIA() { |
|
return detail::getMTIAHooks().hasMTIA(); |
|
} |
|
static bool hasCUDART() { |
|
return detail::getCUDAHooks().hasCUDART(); |
|
} |
|
static long versionCUDART() { |
|
return detail::getCUDAHooks().versionCUDART(); |
|
} |
|
static bool hasCuDNN() { |
|
return detail::getCUDAHooks().hasCuDNN(); |
|
} |
|
static long versionCuDNN() { |
|
return detail::getCUDAHooks().versionCuDNN(); |
|
} |
|
static bool hasCuSOLVER() { |
|
return detail::getCUDAHooks().hasCuSOLVER(); |
|
} |
|
static bool hasCuBLASLt() { |
|
return detail::getCUDAHooks().hasCuBLASLt(); |
|
} |
|
static bool hasHIP() { |
|
return detail::getHIPHooks().hasHIP(); |
|
} |
|
static bool hasMPS() { |
|
return detail::getMPSHooks().hasMPS(); |
|
} |
|
static bool hasIPU() { |
|
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU); |
|
} |
|
static bool hasXLA() { |
|
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA); |
|
} |
|
static bool hasXPU() { |
|
return detail::getXPUHooks().hasXPU(); |
|
} |
|
static bool hasLazy() { |
|
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy); |
|
} |
|
static bool hasMAIA() { |
|
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA); |
|
} |
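  // The lazyInit* methods below are defined in the header so that the
  // c10::call_once fast path can be inlined at frequent call sites.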
|
void lazyInitCUDA() { |
|
c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); }); |
|
} |
|
void lazyInitHIP() { |
|
c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); }); |
|
} |
|
void lazyInitXPU() { |
|
c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); }); |
|
} |
|
void lazyInitMTIA() { |
|
c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); }); |
|
} |
|
void lazyInitPrivateUse1() { |
|
c10::call_once(thp_init, [&] { |
|
if (isPrivateUse1HooksRegistered()) { |
|
at::GetPrivateUse1HooksInterface()->initPrivateUse1(); |
|
} |
|
}); |
|
} |
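  // Each lazyInit* method is idempotent: c10::call_once runs the backend
  // initialization at most once, even under concurrent callers. A typical
  // call site (hypothetical) is:
  //   at::globalContext().lazyInitCUDA();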
|
static const at::cuda::NVRTC& getNVRTC() { |
|
return detail::getCUDAHooks().nvrtc(); |
|
} |
|
|
|
static bool setFlushDenormal(bool on); |
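  // Example (a sketch): flush denormal floats to zero on CPUs that support
  // it, as torch.set_flush_denormal(True) does; returns false if unsupported.
  //   at::globalContext().setFlushDenormal(true);

  // NB: userEnabledCuDNN() is purely whether the user requested that cuDNN
  // be enabled; it says nothing about whether cuDNN is actually usable.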
|
bool userEnabledCuDNN() const; |
|
void setUserEnabledCuDNN(bool e); |
|
bool userEnabledMkldnn() const; |
|
void setUserEnabledMkldnn(bool e); |
|
bool benchmarkCuDNN() const; |
|
void setBenchmarkCuDNN(bool); |
|
int benchmarkLimitCuDNN() const; |
|
void setBenchmarkLimitCuDNN(int); |
|
bool deterministicCuDNN() const; |
|
void setDeterministicCuDNN(bool); |
|
bool userEnabledNNPACK() const; |
|
void setUserEnabledNNPACK(bool e); |
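  // Example (a sketch): enable cuDNN autotuning and cap the number of
  // candidate algorithms it may try:
  //   at::globalContext().setBenchmarkCuDNN(true);
  //   at::globalContext().setBenchmarkLimitCuDNN(20);

  // Note [Disabling Fused SDP Kernels]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // The flags below let users opt individual scaled-dot-product-attention
  // backends (flash, memory-efficient, math, cuDNN) in or out, e.g. to work
  // around an issue in one backend by falling back to another.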
|
void setSDPUseFlash(bool); |
|
bool userEnabledFlashSDP() const; |
|
|
|
void setSDPUseMemEfficient(bool); |
|
bool userEnabledMemEfficientSDP() const; |
|
|
|
void setSDPUseMath(bool); |
|
bool userEnabledMathSDP() const; |
|
|
|
void setSDPUseCuDNN(bool); |
|
bool userEnabledCuDNNSDP() const; |
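  // For example (a sketch), to force the math fallback only:
  //   at::globalContext().setSDPUseFlash(false);
  //   at::globalContext().setSDPUseMemEfficient(false);
  //   at::globalContext().setSDPUseCuDNN(false);
  //   at::globalContext().setSDPUseMath(true);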
|
|
|
at::LinalgBackend linalgPreferredBackend() const; |
|
void setLinalgPreferredBackend(at::LinalgBackend); |
|
|
|
at::BlasBackend blasPreferredBackend() const; |
|
void setBlasPreferredBackend(at::BlasBackend); |
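  // Example (a sketch; BlasBackend::Cublaslt also selects hipBLASLt on ROCm
  // builds, matching the TORCH_BLAS_PREFER_HIPBLASLT environment override):
  //   at::globalContext().setBlasPreferredBackend(at::BlasBackend::Cublaslt);

  // Note [Enabling Deterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // When the flags below are set, operations that would otherwise be
  // nondeterministic must either use a deterministic implementation or raise
  // an error (or only a warning, in warn-only mode). This backs
  // torch.use_deterministic_algorithms().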
|
bool deterministicAlgorithms() const; |
|
bool deterministicAlgorithmsWarnOnly() const; |
|
void setDeterministicAlgorithms(bool, bool); |
|
bool deterministicFillUninitializedMemory() const; |
|
void setDeterministicFillUninitializedMemory(bool); |
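  // Example (a sketch), mirroring
  // torch.use_deterministic_algorithms(True, warn_only=True):
  //   at::globalContext().setDeterministicAlgorithms(true, true);

  // Note [Writing Nondeterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations without a deterministic implementation should call
  // alertNotDeterministic() with the name of the caller; it raises an error
  // when deterministic algorithms are required, or only a warning in
  // warn-only mode.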
|
static void alertNotDeterministic(c10::string_view const& caller); |
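  // Raises an error if deterministic algorithms are required but the cuBLAS
  // workspace configuration is nondeterministic (on CUDA >= 10.2,
  // CUBLAS_WORKSPACE_CONFIG must be ":16:8" or ":4096:8"; see NVIDIA's
  // cuBLAS reproducibility notes).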
|
void alertCuBLASConfigNotDeterministic() const; |
|
|
|
void setFloat32MatmulPrecision(const std::string& s); |
|
bool allowTF32CuDNN() const; |
|
void setAllowTF32CuDNN(bool); |
|
bool allowTF32CuBLAS() const; |
|
void setAllowTF32CuBLAS(bool); |
|
Float32MatmulPrecision float32MatmulPrecision() const; |
|
void setFloat32MatmulPrecision(Float32MatmulPrecision p); |
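  // Example (a sketch): "high" maps to Float32MatmulPrecision::HIGH, which
  // permits TF32 for float32 matmuls, just as
  // torch.set_float32_matmul_precision("high") does:
  //   at::globalContext().setFloat32MatmulPrecision("high");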
|
bool allowFP16ReductionCuBLAS() const; |
|
void setAllowFP16ReductionCuBLAS(bool); |
|
bool allowBF16ReductionCuBLAS() const; |
|
void setAllowBF16ReductionCuBLAS(bool); |
|
at::QEngine qEngine() const; |
|
void setQEngine(at::QEngine e); |
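  // Example (a sketch): select a quantization engine, e.g. on mobile:
  //   at::globalContext().setQEngine(at::QEngine::QNNPACK);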
|
static const std::vector<at::QEngine>& supportedQEngines(); |
|
static bool isXNNPACKAvailable(); |
|
void setCheckSparseTensorInvariants(bool e); |
|
bool checkSparseTensorInvariants() const; |
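  // setReleaseWeightsWhenPrepacking() releases the original weights after
  // pre-packing; it should be set once, before loading or running the model,
  // and is not thread-safe.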
|
void setReleaseWeightsWhenPrepacking(bool e); |
|
bool releaseWeightsWhenPrepacking() const; |
|
|
|
void setDisplayVmapFallbackWarnings(bool enabled); |
|
bool areVmapFallbackWarningsEnabled() const; |
|
|
|
void setDefaultMobileCPUAllocator(); |
|
void unsetDefaultMobileCPUAllocator(); |
|
bool allowFP16ReductionCPU() const; |
|
void setAllowFP16ReductionCPU(bool); |
|
|
|
private: |
|
void initCUDAIfNeeded(c10::DeviceType p) { |
|
if (p == c10::DeviceType::CUDA) { |
|
lazyInitCUDA(); |
|
} |
|
} |
|
void initHIPIfNeeded(c10::DeviceType p) { |
|
if (p == c10::DeviceType::HIP) { |
|
lazyInitHIP(); |
|
} |
|
} |
|
void initXPUIfNeeded(c10::DeviceType p) { |
|
if (p == c10::DeviceType::XPU) { |
|
lazyInitXPU(); |
|
} |
|
} |
|
static bool checkCuBLASConfigDeterministic(); |
|
c10::once_flag thc_init; |
|
c10::once_flag thh_init; |
|
c10::once_flag thx_init; |
|
c10::once_flag th_mtia_init; |
|
c10::once_flag thp_init; |
|
bool enabled_cudnn = true; |
|
bool deterministic_cudnn = false; |
|
bool _deterministic_algorithms = false; |
|
bool _deterministic_algorithms_warn_only = false; |
|
bool _deterministic_fill_uninitialized_memory = true; |
|
bool enabled_flashSDP = true; |
|
bool enabled_mem_efficientSDP = true; |
|
bool enabled_mathSDP = true; |
|
bool enabled_cudnnSDP = false; |
|
#ifdef USE_ROCM |
|
bool benchmark_cudnn = true; |
|
#else |
|
bool benchmark_cudnn = false; |
|
#endif |
|
Float32MatmulPrecision float32_matmul_precision = |
|
c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true |
|
? at::Float32MatmulPrecision::HIGH |
|
: at::Float32MatmulPrecision::HIGHEST; |
|
int benchmark_limit_cudnn = 10; |
|
bool allow_tf32_cudnn = true; |
|
bool allow_fp16_reduction_cublas = true; |
|
bool allow_bf16_reduction_cublas = true; |
|
bool enabled_mkldnn = true; |
|
bool enabled_nnpack = true; |
|
at::LinalgBackend linalg_preferred_backend = |
|
c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true |
|
? at::LinalgBackend::Cusolver |
|
: at::LinalgBackend::Default; |
|
at::BlasBackend blas_preferred_backend = |
|
(c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true || |
|
c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) |
|
? at::BlasBackend::Cublaslt |
|
: at::BlasBackend::Cublas; |
|
#ifdef C10_MOBILE |
|
bool release_original_weights = true; |
|
#else |
|
bool release_original_weights = false; |
|
#endif |
|
bool display_vmap_fallback_warnings_ = false; |
|
std::optional<at::QEngine> quantized_engine = c10::nullopt; |
|
bool enable_sparse_tensor_invariant_checks = false; |
|
bool allow_fp16_reduction_cpu = false; |
|
|
|
Allocator* prev_allocator_ptr_{nullptr}; |
|
}; |
|
|
|
TORCH_API Context& globalContext(); |
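// init() below simply touches globalContext() so that the global Context is
// constructed eagerly; calling it more than once is harmless.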
|
|
|
static inline void init() { |
|
globalContext(); |
|
} |
|
|
|
TORCH_API Allocator* getCPUAllocator(); |
|
|
|
static inline DeprecatedTypeProperties& getDeprecatedTypeProperties( |
|
Backend p, |
|
ScalarType s) { |
|
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( |
|
p, s); |
|
} |
|
|
|
static inline DeprecatedTypeProperties& CPU(ScalarType s) { |
|
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( |
|
Backend::CPU, s); |
|
} |
|
|
|
static inline DeprecatedTypeProperties& CUDA(ScalarType s) { |
|
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( |
|
Backend::CUDA, s); |
|
} |
|
|
|
static inline DeprecatedTypeProperties& HIP(ScalarType s) { |
|
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( |
|
Backend::HIP, s); |
|
} |
|
|
|
static inline DeprecatedTypeProperties& MPS(ScalarType s) { |
|
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( |
|
Backend::MPS, s); |
|
} |
|
|
|
static inline bool hasCUDA() { |
|
return globalContext().hasCUDA(); |
|
} |
|
|
|
static inline bool hasMTIA() { |
|
return globalContext().hasMTIA(); |
|
} |
|
|
|
static inline bool hasHIP() { |
|
return globalContext().hasHIP(); |
|
} |
|
|
|
static inline bool hasIPU() { |
|
return globalContext().hasIPU(); |
|
} |
|
|
|
static inline bool hasXLA() { |
|
return globalContext().hasXLA(); |
|
} |
|
|
|
static inline bool hasMPS() { |
|
return globalContext().hasMPS(); |
|
} |
|
|
|
static inline bool hasMAIA() { |
|
return globalContext().hasMAIA(); |
|
} |
|
|
|
static inline bool hasXPU() { |
|
return globalContext().hasXPU(); |
|
} |
|
|
static inline size_t getNumGPUs() { |
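  // NB: despite the generic name, this counts CUDA GPUs (or HIP GPUs on a
  // ROCm build, where HIP masquerades as CUDA); other device types are
  // deliberately not handled here.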
|
if (hasCUDA() && hasHIP()) { |
|
throw std::runtime_error( |
|
"Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades " |
|
"to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually " |
|
"means HIP. Rebuild PyTorch with one or the other disabled."); |
|
} else if (hasCUDA()) { |
|
return detail::getCUDAHooks().getNumGPUs(); |
|
} else if (hasHIP()) { |
|
return detail::getHIPHooks().getNumGPUs(); |
|
} else { |
|
return 0; |
|
} |
|
} |
|
|
|
static inline bool hasOpenMP() { |
|
return globalContext().hasOpenMP(); |
|
} |
|
|
|
static inline bool hasMKL() { |
|
return globalContext().hasMKL(); |
|
} |
|
|
|
static inline bool hasLAPACK() { |
|
return globalContext().hasLAPACK(); |
|
} |
|
|
|
static inline bool hasMAGMA() { |
|
return globalContext().hasMAGMA(); |
|
} |
|
|
|
static inline bool hasMKLDNN() { |
|
return globalContext().hasMKLDNN(); |
|
} |
|
|
|
static inline void manual_seed(uint64_t seed) { |
|
auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU); |
|
{ |
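    // See Note [Acquire lock when using random generators]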
|
std::lock_guard<std::mutex> lock(gen.mutex()); |
|
gen.set_current_seed(seed); |
|
} |
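  // NB: Sometimes we build with CUDA, but we don't have any GPUs available.
  // In that case, we must not seed CUDA; it will fail!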
|
const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs(); |
|
if (hasCUDA() && cuda_num_gpus > 0) { |
|
for (const auto i : c10::irange(cuda_num_gpus)) { |
|
auto cuda_gen = globalContext().defaultGenerator( |
|
Device(at::kCUDA, static_cast<c10::DeviceIndex>(i))); |
|
{ |
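        // See Note [Acquire lock when using random generators]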
|
std::lock_guard<std::mutex> lock(cuda_gen.mutex()); |
|
cuda_gen.set_current_seed(seed); |
|
} |
|
} |
|
} |
|
|
|
const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs(); |
|
if (hasXPU() && xpu_num_gpus) { |
|
for (const auto i : c10::irange(xpu_num_gpus)) { |
|
auto xpu_gen = globalContext().defaultGenerator( |
|
Device(at::kXPU, static_cast<c10::DeviceIndex>(i))); |
|
{ |
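        // See Note [Acquire lock when using random generators]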
|
std::lock_guard<std::mutex> lock(xpu_gen.mutex()); |
|
xpu_gen.set_current_seed(seed); |
|
} |
|
} |
|
} |
|
|
|
if (hasMPS()) { |
|
auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS); |
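    // See Note [Acquire lock when using random generators]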
|
std::lock_guard<std::mutex> lock(mps_gen.mutex()); |
|
mps_gen.set_current_seed(seed); |
|
} |
|
} |
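// When the global flag allow_tf32 is set to true, cuBLAS handles are
// automatically configured to use the TF32 math mode. For some operators,
// such as addmv, TF32 offers no performance gain but costs precision, so
// NoTF32Guard provides a RAII way to disable TF32 within a scope:
//   at::NoTF32Guard disable_tf32;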
|
struct TORCH_API NoTF32Guard { |
|
NoTF32Guard(); |
|
~NoTF32Guard(); |
|
static bool should_disable_tf32(); |
|
|
|
private: |
|
bool changed = false; |
|
}; |
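// RAII guard that marks the current thread as being inside a ROCm backward
// pass; is_backward_pass() reports whether any such guard is active.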
|
|
|
struct TORCH_API ROCmBackwardPassGuard { |
|
ROCmBackwardPassGuard(); |
|
~ROCmBackwardPassGuard(); |
|
static bool is_backward_pass(); |
|
}; |
|
|
|
} // namespace at
|
|