|
|
|
|
|
#include <torch/extension.h> |
|
#include <ATen/cuda/CUDAContext.h> |
|
#include <ATen/cuda/CUDAUtils.h> |
|
#include <vector> |
|
#include <stdexcept> |
|
#include <string> |
|
|
|
#ifdef SLANG_LLVM |
|
#include "slang-llvm.h" |
|
#else |
|
# if SLANG_GCC_FAMILY && __GNUC__ < 6 |
|
# include <cmath> |
|
# define SLANG_PRELUDE_STD std:: |
|
# else |
|
# include <math.h> |
|
# define SLANG_PRELUDE_STD |
|
# endif |
|
|
|
# include <assert.h> |
|
# include <stdlib.h> |
|
# include <string.h> |
|
# include <stdint.h> |
|
#endif |
|
|
|
#include "../source/core/slang-string.h" |
|
|
|
// Marks a symbol for export from a shared library: __declspec(dllexport) when
// compiling with MSVC, default symbol visibility (GCC/Clang attribute) otherwise.
#if defined(_MSC_VER)
# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
#else
# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
#endif

// C-linkage helpers: expand to `extern "C"` (and the opening/closing of an
// `extern "C" { ... }` block) when compiled as C++, and to nothing when compiled as C.
#ifdef __cplusplus
# define SLANG_PRELUDE_EXTERN_C extern "C"
# define SLANG_PRELUDE_EXTERN_C_START extern "C" {
# define SLANG_PRELUDE_EXTERN_C_END }
#else
# define SLANG_PRELUDE_EXTERN_C
# define SLANG_PRELUDE_EXTERN_C_START
# define SLANG_PRELUDE_EXTERN_C_END
#endif

// Defined empty. NOTE(review): presumably means prelude code is not wrapped in
// a namespace for this target -- confirm against the code that uses it.
#define SLANG_PRELUDE_NAMESPACE

// Fallback definitions, applied only when these macros were not already defined
// earlier in the translation unit (presumably by an included Slang header).
#ifndef SLANG_NO_THROW
# define SLANG_NO_THROW
#endif
#ifndef SLANG_STDCALL
# define SLANG_STDCALL
#endif
// Member-call convention defaults to whatever SLANG_STDCALL expands to.
#ifndef SLANG_MCALL
# define SLANG_MCALL SLANG_STDCALL
#endif
// Note: this fallback is plain `inline` (a hint), not a compiler-specific force-inline attribute.
#ifndef SLANG_FORCE_INLINE
# define SLANG_FORCE_INLINE inline
#endif
|
#include "slang-cpp-types-core.h" |
|
#include "slang-cpp-scalar-intrinsics.h" |
|
|
|
|
|
// Largest tensor rank a TensorView can describe; sizes the fixed-length
// strides/sizes arrays and bounds the rank accepted by make_tensor_view.
static constexpr int kSlangTorchTensorMaxDim = 5;
|
|
|
// Flat description of a tensor's memory handed from the PyTorch binding to
// Slang-generated code. It holds a raw, non-owning pointer (no destructor), so
// the tensor it was built from must stay alive while the view is in use.
// NOTE(review): the field order/types presumably have to match a TensorView type
// on the Slang/CUDA side -- do not reorder or retype without checking that side.
struct TensorView
{
    // Start of the tensor's storage; make_tensor_view allows this to be null
    // only for an empty tensor.
    uint8_t* data;
    // Per-dimension strides in BYTES (element stride * element size); only the
    // first dimensionCount entries are filled in by make_tensor_view.
    uint32_t strides[kSlangTorchTensorMaxDim];
    // Per-dimension sizes in elements; only the first dimensionCount entries are filled in.
    uint32_t sizes[kSlangTorchTensorMaxDim];
    // Number of valid dimensions; make_tensor_view rejects ranks above kSlangTorchTensorMaxDim.
    uint32_t dimensionCount;
};
|
|
|
|
|
// Validates a PyTorch tensor and describes its memory as a TensorView for
// Slang-generated CUDA code.
//
// No implicit dtype cast or device transfer is attempted; any mismatch is
// reported as an error instead.
//
// Parameters:
//   val               - tensor to describe; must be a CUDA tensor.
//   name              - label prepended to every error message.
//   targetScalarType  - dtype the tensor must already have.
//   requireContiguous - when true, non-contiguous tensors are rejected.
//
// Returns a view whose `data` is the tensor's untyped data pointer, whose
// `strides` are in bytes and whose `sizes` are in elements. The view does not
// own the memory; `val`'s storage must outlive it.
//
// Throws std::runtime_error when the tensor is not on a CUDA device, has the
// wrong dtype, is non-contiguous while contiguity is required, has more than
// kSlangTorchTensorMaxDim dimensions, has a zero (broadcast) stride, has a size
// or byte stride that does not fit in 32 bits, or has elements but a null data
// pointer.
TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarType targetScalarType, bool requireContiguous)
{
    if (!val.device().is_cuda())
        throw std::runtime_error(std::string(name).append(": tensor is not on CUDA device."));

    if (val.scalar_type() != targetScalarType)
        throw std::runtime_error(std::string(name).append(": tensor is not of the expected type."));

    if (requireContiguous && !val.is_contiguous())
        throw std::runtime_error(std::string(name).append(": tensor is not contiguous."));

    // Check the rank before anything else: strides/sizes are fixed-size arrays.
    const int64_t dimCount = val.dim();
    if (dimCount > kSlangTorchTensorMaxDim)
        throw std::runtime_error(std::string(name)
                                     .append(": number of dimensions exceeds limit (")
                                     .append(std::to_string(kSlangTorchTensorMaxDim))
                                     .append(")"));

    TensorView res = {};
    res.dimensionCount = static_cast<uint32_t>(dimCount);

    // Use the untyped accessors. data_ptr<T>() throws unless T matches the dtype
    // exactly (e.g. data_ptr<uint8_t>() on a kInt8 tensor), and element_size()
    // covers every dtype without a hand-maintained table.
    const int64_t elementSize = val.element_size();
    res.data = static_cast<uint8_t*>(val.data_ptr());

    // TensorView stores sizes and byte strides as uint32_t; refuse values that
    // would otherwise be silently truncated.
    const auto toU32 = [name](int64_t value, const char* what) -> uint32_t {
        const uint32_t narrowed = static_cast<uint32_t>(value);
        if (value < 0 || static_cast<int64_t>(narrowed) != value)
            throw std::runtime_error(std::string(name).append(": tensor ").append(what).append(" does not fit in 32 bits."));
        return narrowed;
    };

    for (int64_t i = 0; i < dimCount; ++i)
    {
        // A zero element stride means the dimension is broadcast (e.g. by expand()).
        // Test the untruncated stride so a large stride cannot masquerade as zero.
        if (val.stride(i) == 0)
            throw std::runtime_error(std::string(name).append(": tensors with broadcasted dimensions are not supported (use tensor.contiguous() to make tensor whole)"));

        res.strides[i] = toU32(val.stride(i) * elementSize, "byte stride");
        res.sizes[i] = toU32(val.size(i), "size");
    }

    // A tensor with zero elements (any dimension of size 0) may legitimately have
    // no storage, so a null pointer is only an error when there is data to read.
    if (!res.data && val.numel() != 0)
        throw std::runtime_error(std::string(name).append(": data pointer is invalid."));

    return res;
}
|
|
|
#define SLANG_PRELUDE_EXPORT |
|
|