File size: 5,208 Bytes
7e50900 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
#pragma once
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TransposeType.h>
#include <c10/util/complex.h>
#include <c10/core/ScalarType.h>
#include <c10/core/Scalar.h>
namespace at {
namespace native {
namespace cpublas {
namespace internal {

// Helper used by the templated `gemm` wrapper in this header, which calls it
// immediately before dispatching to `gemm_stub`. Only declared here.
//
// The three leading dimensions are taken by pointer, so the callee is able to
// overwrite `*lda`, `*ldb` and `*ldc`; the transpose flags and the sizes
// `m`, `n`, `k` are passed by value.
//
// NOTE(review): judging by the name, this presumably rewrites the leading
// dimensions for degenerate shapes -- confirm against the definition.
void normalize_last_dims(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    int64_t* lda,
    int64_t* ldb,
    int64_t* ldc);

}  // namespace internal
// Function-pointer type for a type-erased GEMM kernel.
//
// The element type travels as the run-time tag `type`; the matrices `a`, `b`
// (read-only) and `c` (writable) are untyped pointers, and the scaling factors
// `alpha` / `beta` are passed as `Scalar`.
using gemm_fn = void (*)(
    at::ScalarType type,
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const Scalar& alpha,
    const void* a,
    int64_t lda,
    const void* b,
    int64_t ldb,
    const Scalar& beta,
    void* c,
    int64_t ldc);

// Declares the dispatch stub `gemm_stub` whose kernels have type `gemm_fn`.
DECLARE_DISPATCH(gemm_fn, gemm_stub);
// Generic GEMM entry point for any `scalar_t` that has a
// `c10::CppTypeToScalarType` mapping.
//
// `alpha` and `beta` use `at::opmath_type<scalar_t>` rather than `scalar_t`
// itself. The call is forwarded to `gemm_stub` for `kCPU` together with the
// run-time dtype tag derived from `scalar_t`.
//
// NOTE(review): parameter names mirror BLAS GEMM; the arithmetic itself is
// done by whichever kernel is registered for `gemm_stub`, not visible here.
template <typename scalar_t>
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<scalar_t> alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    at::opmath_type<scalar_t> beta,
    scalar_t* c,
    int64_t ldc) {
  // lda/ldb/ldc are by-value parameters, so passing their addresses only
  // changes what gets forwarded below, never the caller's variables.
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);

  // Run-time type tag matching the compile-time element type.
  const at::ScalarType dtype = c10::CppTypeToScalarType<scalar_t>::value;
  gemm_stub(
      kCPU, dtype,
      transa, transb,
      m, n, k,
      alpha,
      a, lda,
      b, ldb,
      beta,
      c, ldc);
}
// Non-template `gemm` overload for `double` matrices with `double` scaling
// factors. Only declared here. As an ordinary function it is preferred over
// the `gemm` template when the arguments match these types exactly.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    double alpha,
    const double* a,
    int64_t lda,
    const double* b,
    int64_t ldb,
    double beta,
    double* c,
    int64_t ldc);
// Non-template `gemm` overload for `float` matrices with `float` scaling
// factors. Only declared here. As an ordinary function it is preferred over
// the `gemm` template when the arguments match these types exactly.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const float* a,
    int64_t lda,
    const float* b,
    int64_t ldb,
    float beta,
    float* c,
    int64_t ldc);
// Non-template `gemm` overload for `at::BFloat16` matrices. Unlike the
// element type, the scaling factors `alpha` and `beta` are plain `float`.
// Only declared here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* a,
    int64_t lda,
    const at::BFloat16* b,
    int64_t ldb,
    float beta,
    at::BFloat16* c,
    int64_t ldc);
// Non-template `gemm` overload for `c10::complex<double>` matrices with
// scaling factors of the same type. Only declared here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    c10::complex<double> alpha,
    const c10::complex<double>* a,
    int64_t lda,
    const c10::complex<double>* b,
    int64_t ldb,
    c10::complex<double> beta,
    c10::complex<double>* c,
    int64_t ldc);
// Non-template `gemm` overload for `c10::complex<float>` matrices with
// scaling factors of the same type. Only declared here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    c10::complex<float> alpha,
    const c10::complex<float>* a,
    int64_t lda,
    const c10::complex<float>* b,
    int64_t ldb,
    c10::complex<float> beta,
    c10::complex<float>* c,
    int64_t ldc);
// Non-template `gemm` overload for `int64_t` matrices with `int64_t` scaling
// factors. Only declared here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    int64_t alpha,
    const int64_t* a,
    int64_t lda,
    const int64_t* b,
    int64_t ldb,
    int64_t beta,
    int64_t* c,
    int64_t ldc);
// Batched GEMM template; only declared here.
//
// `a`, `b` and `c` are pointers to arrays of matrix pointers: the pointer
// arrays themselves are const, the elements behind `a` and `b` are read-only
// and those behind `c` are writable. A single `lda` / `ldb` / `ldc` and one
// `alpha` / `beta` pair are supplied for the whole call.
//
// NOTE(review): presumably each array holds `batch_size` entries, one matrix
// per batch item -- confirm against the definition.
// NOTE(review): the top-level `const` on `beta` does not change the function
// signature; it is reproduced as in the original declaration.
template <typename scalar_t>
void gemm_batched(
    TransposeType transa,
    TransposeType transb,
    int64_t batch_size,
    int64_t m,
    int64_t n,
    int64_t k,
    scalar_t alpha,
    const scalar_t* const* a,
    int64_t lda,
    const scalar_t* const* b,
    int64_t ldb,
    const scalar_t beta,
    scalar_t* const* c,
    int64_t ldc);
// Batched GEMM template taking one base pointer per operand plus a per-operand
// batch stride, instead of arrays of matrix pointers. Only declared here.
//
// NOTE(review): presumably batch item `i` of operand `a` starts at
// `a + i * batch_stride_a` (likewise for `b` and `c`) -- confirm against the
// definition.
template <typename scalar_t>
void gemm_batched_with_stride(
    TransposeType transa,
    TransposeType transb,
    int64_t batch_size,
    int64_t m,
    int64_t n,
    int64_t k,
    scalar_t alpha,
    const scalar_t* a,
    int64_t lda,
    int64_t batch_stride_a,
    const scalar_t* b,
    int64_t ldb,
    int64_t batch_stride_b,
    scalar_t beta,
    scalar_t* c,
    int64_t ldc,
    int64_t batch_stride_c);
// Function-pointer type for a type-erased axpy kernel: the element type is
// given by the run-time tag `type`, `x` is a read-only untyped pointer, `y` a
// writable one, and the scalar `a` is passed as `Scalar`.
using axpy_fn = void (*)(
    at::ScalarType type,
    int64_t n,
    const Scalar& a,
    const void* x,
    int64_t incx,
    void* y,
    int64_t incy);

// Declares the dispatch stub `axpy_stub` whose kernels have type `axpy_fn`.
DECLARE_DISPATCH(axpy_fn, axpy_stub);
template<typename scalar_t>
void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
if(n == 1)
{
incx = 1;
incy = 1;
}
axpy_stub(
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
n, a, x, incx, y, incy);
}
// Non-template `axpy` overload for `double`: scalar `a`, read-only `x` and
// writable `y`, each pointer paired with its own increment. Only declared here.
void axpy(
    int64_t n,
    double a,
    const double* x,
    int64_t incx,
    double* y,
    int64_t incy);
// Non-template `axpy` overload for `float`: scalar `a`, read-only `x` and
// writable `y`, each pointer paired with its own increment. Only declared here.
void axpy(
    int64_t n,
    float a,
    const float* x,
    int64_t incx,
    float* y,
    int64_t incy);
// Non-template `axpy` overload for `c10::complex<double>`: scalar `a`,
// read-only `x` and writable `y`, each pointer paired with its own increment.
// Only declared here.
void axpy(
    int64_t n,
    c10::complex<double> a,
    const c10::complex<double>* x,
    int64_t incx,
    c10::complex<double>* y,
    int64_t incy);
// Non-template `axpy` overload for `c10::complex<float>`: scalar `a`,
// read-only `x` and writable `y`, each pointer paired with its own increment.
// Only declared here.
void axpy(
    int64_t n,
    c10::complex<float> a,
    const c10::complex<float>* x,
    int64_t incx,
    c10::complex<float>* y,
    int64_t incy);
// Function-pointer type for a type-erased copy kernel: the element type is
// given by the run-time tag `type`, `x` is a read-only untyped pointer and `y`
// a writable one, each with its own increment.
using copy_fn = void (*)(
    at::ScalarType type,
    int64_t n,
    const void* x,
    int64_t incx,
    void* y,
    int64_t incy);

// Declares the dispatch stub `copy_stub` whose kernels have type `copy_fn`.
DECLARE_DISPATCH(copy_fn, copy_stub);
template<typename scalar_t>
void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
if(n == 1)
{
incx = 1;
incy = 1;
}
copy_stub(
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
n, x, incx, y, incy);
}
// Non-template `copy` overload for `double`: read-only source `x` and writable
// destination `y`, each paired with its own increment. Only declared here.
void copy(
    int64_t n,
    const double* x,
    int64_t incx,
    double* y,
    int64_t incy);
// Non-template `copy` overload for `float`: read-only source `x` and writable
// destination `y`, each paired with its own increment. Only declared here.
void copy(
    int64_t n,
    const float* x,
    int64_t incx,
    float* y,
    int64_t incy);
// Non-template `copy` overload for `c10::complex<double>`: read-only source
// `x` and writable destination `y`, each paired with its own increment.
// Only declared here.
void copy(
    int64_t n,
    const c10::complex<double>* x,
    int64_t incx,
    c10::complex<double>* y,
    int64_t incy);
// Non-template `copy` overload for `c10::complex<float>`: read-only source
// `x` and writable destination `y`, each paired with its own increment.
// Only declared here.
void copy(
    int64_t n,
    const c10::complex<float>* x,
    int64_t incx,
    c10::complex<float>* y,
    int64_t incy);
}}} // namespace at::native::cpublas
|