|
|
|
|
|
#ifndef DLIB_DNN_CuBLAS_CPP_ |
|
#define DLIB_DNN_CuBLAS_CPP_ |
|
|
|
#ifdef DLIB_USE_CUDA |
|
|
|
#include "cublas_dlibapi.h"
#include "cuda_utils.h"

#include <cublas_v2.h>
#include <sstream>
#include <vector>
|
|
|
// Maps a cuBLAS status code to a short human-readable description used in
// the error messages built by CHECK_CUBLAS.  Codes without a dedicated entry
// (including CUBLAS_STATUS_SUCCESS) fall back to a generic message.
static const char* cublas_get_error_string(cublasStatus_t s)
{
    switch(s)
    {
        case CUBLAS_STATUS_NOT_INITIALIZED:
            return "CUDA Runtime API initialization failed.";
        case CUBLAS_STATUS_ALLOC_FAILED:
            return "CUDA Resources could not be allocated.";
        case CUBLAS_STATUS_INVALID_VALUE:
            return "An unsupported value or parameter was passed to the function.";
        case CUBLAS_STATUS_ARCH_MISMATCH:
            return "The function requires a feature absent from the device architecture.";
        case CUBLAS_STATUS_MAPPING_ERROR:
            return "An access to GPU memory space failed.";
        case CUBLAS_STATUS_EXECUTION_FAILED:
            return "The GPU program failed to execute.";
        case CUBLAS_STATUS_INTERNAL_ERROR:
            return "An internal cuBLAS operation failed.";
        case CUBLAS_STATUS_NOT_SUPPORTED:
            return "The requested functionality is not supported.";
        default:
            return "A call to cuBLAS failed";
    }
}
|
|
|
|
|
// Evaluates the cuBLAS call exactly once.  If it does not return
// CUBLAS_STATUS_SUCCESS, throws dlib::cublas_error describing the call text,
// source location, numeric status code and cublas_get_error_string() reason.
// The local status variable uses an unusual name and the argument is
// parenthesized so that an identifier inside `call` (e.g. a caller variable
// named `error`) cannot be captured by the macro's own declaration.
#define CHECK_CUBLAS(call)                                                      \
do{                                                                             \
    const cublasStatus_t dlib_cublas_status_ = (call);                          \
    if (dlib_cublas_status_ != CUBLAS_STATUS_SUCCESS)                           \
    {                                                                           \
        std::ostringstream sout;                                                \
        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
        sout << "code: " << dlib_cublas_status_ << ", reason: " << cublas_get_error_string(dlib_cublas_status_);\
        throw dlib::cublas_error(sout.str());                                   \
    }                                                                           \
}while(false)
|
|
|
namespace dlib |
|
{ |
|
namespace cuda |
|
{ |
|
|
|
|
|
|
|
// Holds at most one cuBLAS handle per CUDA device id.  Handles are created
// lazily by get_handle() and released in the destructor.  Copying is
// disabled because the object owns the handles it creates.
class cublas_context
{
public:

    cublas_context(const cublas_context&) = delete;
    cublas_context& operator=(const cublas_context&) = delete;

    // Start with 16 empty (null) slots; more are added on demand.
    cublas_context() : handles(16) {}

    ~cublas_context()
    {
        // Destroy only the slots that were populated.  The returned status is
        // not checked here (best-effort teardown).
        for (const auto& handle : handles)
            if (handle)
                cublasDestroy(handle);
    }

    // Returns the handle for the device reported by cudaGetDevice(), creating
    // it on first request.  Failures are reported through CHECK_CUDA /
    // CHECK_CUBLAS.
    cublasHandle_t get_handle (
    )
    {
        int device = 0;
        CHECK_CUDA(cudaGetDevice(&device));

        // Grow the slot table so `device` is a valid index (with headroom).
        if (static_cast<long>(handles.size()) <= device)
            handles.resize(device + 16);

        cublasHandle_t& slot = handles[device];
        if (!slot)
            CHECK_CUBLAS(cublasCreate(&slot));
        return slot;
    }

private:

    // Indexed by CUDA device id; a null entry means "not created yet".
    std::vector<cublasHandle_t> handles;
};
|
|
|
// Returns the cuBLAS handle for the current CUDA device.  The owning
// cublas_context is thread_local, so each host thread gets its own instance
// (and therefore its own set of handles).
static cublasHandle_t context()
{
    thread_local cublas_context per_thread;
    const cublasHandle_t handle = per_thread.get_handle();
    return handle;
}
|
|
|
|
|
|
|
// Computes dest = alpha*op(lhs)*op(rhs) + beta*dest with cublasSgemm, where
// op(X) is X or trans(X) depending on trans_lhs / trans_rhs.
//
// Each tensor is treated as a row-major matrix with num_samples() rows and
// size()/num_samples() columns.  cuBLAS is column-major, so it sees each such
// matrix as its transpose.  We therefore ask cuBLAS for
//     trans(dest) = trans(op(rhs)) * trans(op(lhs)),
// which is why rhs/transb come before lhs/transa in the call, and why the
// m/n arguments are (dest_nc, dest_nr).
void gemm (
    float beta,
    tensor& dest,
    float alpha,
    const tensor& lhs,
    bool trans_lhs,
    const tensor& rhs,
    bool trans_rhs
)
{
    const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
    const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;

    // Row/column counts of the row-major matrix view.  A tensor with no
    // samples has no elements, so it gets 0 columns instead of an integer
    // division by zero.
    const int dest_nr = dest.num_samples();
    const int dest_nc = dest_nr == 0 ? 0 : dest.size()/dest_nr;
    const int lhs_nr = lhs.num_samples();
    const int lhs_nc = lhs_nr == 0 ? 0 : lhs.size()/lhs_nr;
    const int rhs_nr = rhs.num_samples();
    const int rhs_nc = rhs_nr == 0 ? 0 : rhs.size()/rhs_nr;

    // Shape checks: dest must be (rows of op(lhs)) x (cols of op(rhs)), and
    // the inner dimensions of op(lhs) and op(rhs) must agree.
    if (trans_lhs && trans_rhs)
    {
        DLIB_ASSERT( dest_nr == lhs_nc &&
                     dest_nc == rhs_nr &&
                     lhs_nr == rhs_nc)
    }
    else if (!trans_lhs && trans_rhs)
    {
        DLIB_ASSERT( dest_nr == lhs_nr &&
                     dest_nc == rhs_nr &&
                     lhs_nc == rhs_nc)
    }
    else if (trans_lhs && !trans_rhs)
    {
        DLIB_ASSERT( dest_nr == lhs_nc &&
                     dest_nc == rhs_nc &&
                     lhs_nr == rhs_nr)
    }
    else
    {
        DLIB_ASSERT( dest_nr == lhs_nr &&
                     dest_nc == rhs_nc &&
                     lhs_nc == rhs_nr)
    }

    // An empty destination has nothing to compute.  Returning here also
    // avoids handing cuBLAS a zero leading dimension (ldc = dest_nc).
    if (dest.size() == 0)
        return;

    // Inner (summation) dimension of the product.
    const int k = trans_rhs ? rhs_nc : rhs_nr;
    CHECK_CUBLAS(cublasSgemm(context(),
                      transb,
                      transa,
                      dest_nc, dest_nr, k,
                      &alpha,
                      rhs.device(), rhs_nc,
                      lhs.device(), lhs_nc,
                      &beta,
                      dest.device(),dest_nc));
}
|
|
|
|
|
|
|
} |
|
} |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
|