|
#pragma once |
|
|
|
#include <c10/util/Optional.h> |
|
#include <ATen/Config.h> |
|
#include <ATen/native/DispatchStub.h> |
|
|
|
|
|
// Forward declarations only. They let the function-pointer aliases further
// down in this header name these types without pulling in their definitions.
namespace at {

class Tensor;
struct TensorIterator;

namespace native {

enum class TransposeType;

} // namespace native

} // namespace at
|
|
|
namespace at { namespace native { |
|
|
|
// Compile-time tag identifying a least-squares driver; it is the first
// template argument of lapackLstsq / lapackLstsq_impl in this header.
// The underlying type is fixed to int64_t and the enumerators take the
// implicit values Gels = 0, Gelsd = 1, Gelsy = 2, Gelss = 3.
enum class LapackLstsqDriverType : int64_t {
  Gels,
  Gelsd,
  Gelsy,
  Gelss
};
|
|
|
#if AT_BUILD_WITH_LAPACK() |
|
|
|
|
|
|
|
// Cholesky wrapper, templated on the element type. Declaration only: this
// header provides no definition.
// NOTE(review): presumably specialized per dtype to call LAPACK ?potrf, with
// arguments following the LAPACK convention -- confirm in the implementation.
template <class scalar_t>
void lapackCholesky(char uplo, int n, scalar_t* a, int lda, int* info);
|
|
|
// Cholesky-inverse wrapper, templated on the element type. Declaration only:
// this header provides no definition.
// NOTE(review): presumably specialized per dtype to call LAPACK ?potri --
// confirm in the implementation.
template <class scalar_t>
void lapackCholeskyInverse(char uplo, int n, scalar_t* a, int lda, int* info);
|
|
|
// General eigenproblem wrapper. `scalar_t` is the element type of the matrix,
// eigenvalue, eigenvector and `work` buffers; `value_t` (defaulting to
// `scalar_t`) is the element type of `rwork`. Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?geev, with
// value_t being the real counterpart of a complex scalar_t -- confirm.
template <class scalar_t, class value_t = scalar_t>
void lapackEig(
    char jobvl, char jobvr, int n,
    scalar_t* a, int lda,
    scalar_t* w,
    scalar_t* vl, int ldvl,
    scalar_t* vr, int ldvr,
    scalar_t* work, int lwork,
    value_t* rwork,
    int* info);
|
|
|
// QR factorization wrapper, templated on the element type. Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?geqrf --
// confirm in the implementation.
template <class scalar_t>
void lapackGeqrf(
    int m, int n,
    scalar_t* a, int lda,
    scalar_t* tau,
    scalar_t* work, int lwork,
    int* info);
|
|
|
// Wrapper taking the (a, tau) pair plus a reflector count `k`, templated on
// the element type. Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?orgqr /
// ?ungqr -- confirm in the implementation.
template <class scalar_t>
void lapackOrgqr(
    int m, int n, int k,
    scalar_t* a, int lda,
    scalar_t* tau,
    scalar_t* work, int lwork,
    int* info);
|
|
|
// Wrapper taking (a, tau) and a second matrix `c`, selected by `side` and
// `trans`, templated on the element type. Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?ormqr /
// ?unmqr -- confirm in the implementation.
template <class scalar_t>
void lapackOrmqr(
    char side, char trans,
    int m, int n, int k,
    scalar_t* a, int lda,
    scalar_t* tau,
    scalar_t* c, int ldc,
    scalar_t* work, int lwork,
    int* info);
|
|
|
// Symmetric/Hermitian eigenproblem wrapper. `scalar_t` types the matrix and
// `work`; `value_t` (defaulting to `scalar_t`) types `w` and `rwork`.
// Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?syevd /
// ?heevd -- confirm in the implementation.
template <class scalar_t, class value_t = scalar_t>
void lapackSyevd(
    char jobz, char uplo, int n,
    scalar_t* a, int lda,
    value_t* w,
    scalar_t* work, int lwork,
    value_t* rwork, int lrwork,
    int* iwork, int liwork,
    int* info);
|
|
|
// Least-squares "gels" wrapper, templated on the element type. Declaration
// only. Called by lapackLstsq_impl<LapackLstsqDriverType::Gels, ...>::call in
// this header.
// NOTE(review): presumably specialized per dtype to call LAPACK ?gels --
// confirm in the implementation.
template <class scalar_t>
void lapackGels(
    char trans, int m, int n, int nrhs,
    scalar_t* a, int lda,
    scalar_t* b, int ldb,
    scalar_t* work, int lwork,
    int* info);
|
|
|
// Least-squares "gelsd" wrapper. `scalar_t` types a, b and `work`; `value_t`
// (defaulting to `scalar_t`) types `s`, `rcond` and `rwork`. Declaration only.
// Called by lapackLstsq_impl<LapackLstsqDriverType::Gelsd, ...>::call in this
// header.
// NOTE(review): presumably specialized per dtype to call LAPACK ?gelsd --
// confirm in the implementation.
template <class scalar_t, class value_t = scalar_t>
void lapackGelsd(
    int m, int n, int nrhs,
    scalar_t* a, int lda,
    scalar_t* b, int ldb,
    value_t* s, value_t rcond, int* rank,
    scalar_t* work, int lwork,
    value_t* rwork, int* iwork,
    int* info);
|
|
|
// Least-squares "gelsy" wrapper. `scalar_t` types a, b and `work`; `value_t`
// (defaulting to `scalar_t`) types `rcond` and `rwork`. Declaration only.
// Called by lapackLstsq_impl<LapackLstsqDriverType::Gelsy, ...>::call in this
// header.
// NOTE(review): presumably specialized per dtype to call LAPACK ?gelsy --
// confirm in the implementation.
template <class scalar_t, class value_t = scalar_t>
void lapackGelsy(
    int m, int n, int nrhs,
    scalar_t* a, int lda,
    scalar_t* b, int ldb,
    int* jpvt, value_t rcond, int* rank,
    scalar_t* work, int lwork,
    value_t* rwork,
    int* info);
|
|
|
// Least-squares "gelss" wrapper. `scalar_t` types a, b and `work`; `value_t`
// (defaulting to `scalar_t`) types `s`, `rcond` and `rwork`. Declaration only.
// Called by lapackLstsq_impl<LapackLstsqDriverType::Gelss, ...>::call in this
// header.
// NOTE(review): presumably specialized per dtype to call LAPACK ?gelss --
// confirm in the implementation.
template <class scalar_t, class value_t = scalar_t>
void lapackGelss(
    int m, int n, int nrhs,
    scalar_t* a, int lda,
    scalar_t* b, int ldb,
    value_t* s, value_t rcond, int* rank,
    scalar_t* work, int lwork,
    value_t* rwork,
    int* info);
|
|
|
// Primary template for least-squares driver selection. It is declared but not
// defined in this header; the partial specializations on each
// LapackLstsqDriverType enumerator supply the static `call` member.
template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
struct lapackLstsq_impl;
|
|
|
template <class scalar_t, class value_t> |
|
struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> { |
|
static void call( |
|
char trans, int m, int n, int nrhs, |
|
scalar_t *a, int lda, scalar_t *b, int ldb, |
|
scalar_t *work, int lwork, int *info, |
|
int *jpvt, value_t rcond, int *rank, value_t* rwork, |
|
value_t *s, |
|
int *iwork |
|
) { |
|
lapackGels<scalar_t>( |
|
trans, m, n, nrhs, |
|
a, lda, b, ldb, |
|
work, lwork, info); |
|
} |
|
}; |
|
|
|
template <class scalar_t, class value_t> |
|
struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> { |
|
static void call( |
|
char trans, int m, int n, int nrhs, |
|
scalar_t *a, int lda, scalar_t *b, int ldb, |
|
scalar_t *work, int lwork, int *info, |
|
int *jpvt, value_t rcond, int *rank, value_t* rwork, |
|
value_t *s, |
|
int *iwork |
|
) { |
|
lapackGelsy<scalar_t, value_t>( |
|
m, n, nrhs, |
|
a, lda, b, ldb, |
|
jpvt, rcond, rank, |
|
work, lwork, rwork, info); |
|
} |
|
}; |
|
|
|
template <class scalar_t, class value_t> |
|
struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> { |
|
static void call( |
|
char trans, int m, int n, int nrhs, |
|
scalar_t *a, int lda, scalar_t *b, int ldb, |
|
scalar_t *work, int lwork, int *info, |
|
int *jpvt, value_t rcond, int *rank, value_t* rwork, |
|
value_t *s, |
|
int *iwork |
|
) { |
|
lapackGelsd<scalar_t, value_t>( |
|
m, n, nrhs, |
|
a, lda, b, ldb, |
|
s, rcond, rank, |
|
work, lwork, |
|
rwork, iwork, info); |
|
} |
|
}; |
|
|
|
template <class scalar_t, class value_t> |
|
struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> { |
|
static void call( |
|
char trans, int m, int n, int nrhs, |
|
scalar_t *a, int lda, scalar_t *b, int ldb, |
|
scalar_t *work, int lwork, int *info, |
|
int *jpvt, value_t rcond, int *rank, value_t* rwork, |
|
value_t *s, |
|
int *iwork |
|
) { |
|
lapackGelss<scalar_t, value_t>( |
|
m, n, nrhs, |
|
a, lda, b, ldb, |
|
s, rcond, rank, |
|
work, lwork, |
|
rwork, info); |
|
} |
|
}; |
|
|
|
template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t> |
|
void lapackLstsq( |
|
char trans, int m, int n, int nrhs, |
|
scalar_t *a, int lda, scalar_t *b, int ldb, |
|
scalar_t *work, int lwork, int *info, |
|
int *jpvt, value_t rcond, int *rank, value_t* rwork, |
|
value_t *s, |
|
int *iwork |
|
) { |
|
lapackLstsq_impl<driver_type, scalar_t, value_t>::call( |
|
trans, m, n, nrhs, |
|
a, lda, b, ldb, |
|
work, lwork, info, |
|
jpvt, rcond, rank, rwork, |
|
s, |
|
iwork); |
|
} |
|
|
|
// LU-based solve wrapper taking a matrix `a`, pivots `ipiv` and right-hand
// sides `b`, templated on the element type. Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?getrs --
// confirm in the implementation.
template <class scalar_t>
void lapackLuSolve(
    char trans, int n, int nrhs,
    scalar_t* a, int lda,
    int* ipiv,
    scalar_t* b, int ldb,
    int* info);
|
|
|
// LU factorization wrapper, templated on the element type. Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?getrf --
// confirm in the implementation.
template <class scalar_t>
void lapackLu(int m, int n, scalar_t* a, int lda, int* ipiv, int* info);
|
|
|
// LDL factorization wrapper (Hermitian variant), templated on the element
// type. Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?hetrf --
// confirm in the implementation.
template <class scalar_t>
void lapackLdlHermitian(
    char uplo, int n,
    scalar_t* a, int lda,
    int* ipiv,
    scalar_t* work, int lwork,
    int* info);
|
|
|
// LDL factorization wrapper (symmetric variant), templated on the element
// type. Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?sytrf --
// confirm in the implementation.
template <class scalar_t>
void lapackLdlSymmetric(
    char uplo, int n,
    scalar_t* a, int lda,
    int* ipiv,
    scalar_t* work, int lwork,
    int* info);
|
|
|
// LDL-based solve wrapper (Hermitian variant) taking a matrix `a`, pivots
// `ipiv` and right-hand sides `b`, templated on the element type.
// Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?hetrs --
// confirm in the implementation.
template <class scalar_t>
void lapackLdlSolveHermitian(
    char uplo, int n, int nrhs,
    scalar_t* a, int lda,
    int* ipiv,
    scalar_t* b, int ldb,
    int* info);
|
|
|
// LDL-based solve wrapper (symmetric variant) taking a matrix `a`, pivots
// `ipiv` and right-hand sides `b`, templated on the element type.
// Declaration only.
// NOTE(review): presumably specialized per dtype to call LAPACK ?sytrs --
// confirm in the implementation.
template <class scalar_t>
void lapackLdlSolveSymmetric(
    char uplo, int n, int nrhs,
    scalar_t* a, int lda,
    int* ipiv,
    scalar_t* b, int ldb,
    int* info);
|
|
|
// SVD wrapper. `scalar_t` types a, u, vt and `work`; `value_t` (defaulting to
// `scalar_t`) types `s` and `rwork`. Declaration only.
// NOTE(review): the jobz/iwork arguments suggest LAPACK ?gesdd -- confirm in
// the implementation.
template <class scalar_t, class value_t = scalar_t>
void lapackSvd(
    char jobz, int m, int n,
    scalar_t* a, int lda,
    value_t* s,
    scalar_t* u, int ldu,
    scalar_t* vt, int ldvt,
    scalar_t* work, int lwork,
    value_t* rwork,
    int* iwork,
    int* info);
|
#endif |
|
|
|
#if AT_BUILD_WITH_BLAS() |
|
// Triangular solve wrapper, templated on the element type. Declaration only.
// Unlike the lapack* wrappers in this header it has no `info` output.
// NOTE(review): presumably specialized per dtype to call BLAS ?trsm --
// confirm in the implementation.
template <class scalar_t>
void blasTriangularSolve(
    char side, char uplo, char trans, char diag,
    int n, int nrhs,
    scalar_t* a, int lda,
    scalar_t* b, int ldb);
|
#endif |
|
|
|
using cholesky_fn = void (*)(const Tensor& , const Tensor& , bool ); |
|
DECLARE_DISPATCH(cholesky_fn, cholesky_stub); |
|
|
|
using cholesky_inverse_fn = Tensor& (*)(Tensor& , Tensor& , bool ); |
|
|
|
DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub); |
|
|
|
using linalg_eig_fn = void (*)(Tensor& , Tensor& , Tensor& , const Tensor& , bool ); |
|
|
|
DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub); |
|
|
|
using geqrf_fn = void (*)(const Tensor& , const Tensor& ); |
|
DECLARE_DISPATCH(geqrf_fn, geqrf_stub); |
|
|
|
using orgqr_fn = Tensor& (*)(Tensor& , const Tensor& ); |
|
DECLARE_DISPATCH(orgqr_fn, orgqr_stub); |
|
|
|
using ormqr_fn = void (*)(const Tensor& , const Tensor& , const Tensor& , bool , bool ); |
|
DECLARE_DISPATCH(ormqr_fn, ormqr_stub); |
|
|
|
using linalg_eigh_fn = void (*)( |
|
const Tensor& , |
|
const Tensor& , |
|
const Tensor& , |
|
bool , |
|
bool ); |
|
DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub); |
|
|
|
using lstsq_fn = void (*)( |
|
const Tensor& , |
|
Tensor& , |
|
Tensor& , |
|
Tensor& , |
|
Tensor& , |
|
double , |
|
std::string ); |
|
DECLARE_DISPATCH(lstsq_fn, lstsq_stub); |
|
|
|
using triangular_solve_fn = void (*)( |
|
const Tensor& , |
|
const Tensor& , |
|
bool , |
|
bool , |
|
TransposeType , |
|
bool ); |
|
DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub); |
|
|
|
using lu_factor_fn = void (*)( |
|
const Tensor& , |
|
const Tensor& , |
|
const Tensor& , |
|
bool ); |
|
DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub); |
|
|
|
using unpack_pivots_fn = void(*)( |
|
TensorIterator& iter, |
|
const int64_t dim_size, |
|
const int64_t max_pivot); |
|
DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub); |
|
|
|
using lu_solve_fn = void (*)( |
|
const Tensor& , |
|
const Tensor& , |
|
const Tensor& , |
|
TransposeType ); |
|
DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub); |
|
|
|
using ldl_factor_fn = void (*)( |
|
const Tensor& , |
|
const Tensor& , |
|
const Tensor& , |
|
bool , |
|
bool ); |
|
DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub); |
|
|
|
using svd_fn = void (*)( |
|
const Tensor& , |
|
const bool , |
|
const bool , |
|
const c10::optional<c10::string_view>& , |
|
const Tensor& , |
|
const Tensor& , |
|
const Tensor& , |
|
const Tensor& ); |
|
DECLARE_DISPATCH(svd_fn, svd_stub); |
|
|
|
using ldl_solve_fn = void (*)( |
|
const Tensor& , |
|
const Tensor& , |
|
const Tensor& , |
|
bool , |
|
bool ); |
|
DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub); |
|
}} |
|
|