|
|
|
|
|
#pragma once |
|
#include <c10/core/Allocator.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <string>
|
|
|
|
|
// Metal handle aliases.
//
// When this header is compiled as Objective-C(++) the aliases name the real
// Metal protocol-qualified `id` types. Otherwise they are declared as opaque
// `void*` so that the declarations below can still be parsed by plain C++
// translation units.
//
// Each alias is declared exactly once per branch.
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
typedef id<MTLDevice> MTLDevice_t;
typedef id<MTLLibrary> MTLLibrary_t;
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
#else
typedef void* MTLDevice;
typedef void* MTLDevice_t;
typedef void* MTLLibrary_t;
typedef void* MTLComputePipelineState_t;
#endif
|
|
|
namespace at::mps { |
|
|
|
|
|
// Identifiers for macOS version thresholds, stored as uint32_t.
// The enumerator names suggest "this version or newer"; the code that
// interprets them is not part of this header.
// Values are consecutive and start at 0.
enum class MacOSVersion : uint32_t {
  MACOS_VER_13_0_PLUS = 0,
  MACOS_VER_13_1_PLUS = 1,
  MACOS_VER_13_2_PLUS = 2,
  MACOS_VER_13_3_PLUS = 3,
  MACOS_VER_14_0_PLUS = 4,
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TORCH_API MPSDevice { |
|
public: |
|
|
|
|
|
|
|
MPSDevice(MPSDevice& other) = delete; |
|
|
|
|
|
|
|
void operator=(const MPSDevice&) = delete; |
|
|
|
|
|
|
|
static MPSDevice* getInstance(); |
|
|
|
|
|
|
|
MTLDevice_t device() { |
|
return _mtl_device; |
|
} |
|
|
|
|
|
|
|
bool isMacOS13Plus(MacOSVersion version) const; |
|
|
|
MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel); |
|
MTLLibrary_t getMetalIndexingLibrary(); |
|
|
|
~MPSDevice(); |
|
|
|
private: |
|
static MPSDevice* _device; |
|
MTLDevice_t _mtl_device; |
|
MTLLibrary_t _mtl_indexing_library; |
|
MPSDevice(); |
|
}; |
|
|
|
// Free functions exported from this namespace; all are defined out of line.

// NOTE(review): presumably reports whether the MPS backend can be used on
// this machine -- confirm against the implementation file.
TORCH_API bool is_available();

// Version query taking a MacOSVersion threshold; when called without an
// argument the threshold is MacOSVersion::MACOS_VER_13_0_PLUS.
TORCH_API bool is_macos_13_or_newer(
    MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);

// Returns a pointer to an at::Allocator. `useSharedAllocator` defaults to
// false.
// NOTE(review): the flag presumably selects between two allocator
// instances -- confirm against the implementation file.
TORCH_API at::Allocator* GetMPSAllocator(
    bool useSharedAllocator = false);
|
|
|
} |
|
|