// Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DNN_CuBLAS_CPP_ #define DLIB_DNN_CuBLAS_CPP_ #ifdef DLIB_USE_CUDA #include "cublas_dlibapi.h" #include "cuda_utils.h" #include <cublas_v2.h> #include <vector> static const char* cublas_get_error_string(cublasStatus_t s) { switch(s) { case CUBLAS_STATUS_NOT_INITIALIZED: return "CUDA Runtime API initialization failed."; case CUBLAS_STATUS_ALLOC_FAILED: return "CUDA Resources could not be allocated."; default: return "A call to cuBLAS failed"; } } // Check the return value of a call to the cuBLAS runtime for an error condition. #define CHECK_CUBLAS(call) \ do{ \ const cublasStatus_t error = call; \ if (error != CUBLAS_STATUS_SUCCESS) \ { \ std::ostringstream sout; \ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\ sout << "code: " << error << ", reason: " << cublas_get_error_string(error);\ throw dlib::cublas_error(sout.str()); \ } \ }while(false) namespace dlib { namespace cuda { // ----------------------------------------------------------------------------------- class cublas_context { public: // not copyable cublas_context(const cublas_context&) = delete; cublas_context& operator=(const cublas_context&) = delete; cublas_context() { handles.resize(16); } ~cublas_context() { for (auto h : handles) { if (h) cublasDestroy(h); } } cublasHandle_t get_handle ( ) { int new_device_id; CHECK_CUDA(cudaGetDevice(&new_device_id)); // make room for more devices if needed if (new_device_id >= (long)handles.size()) handles.resize(new_device_id+16); // If we don't have a handle already for this device then make one if (!handles[new_device_id]) CHECK_CUBLAS(cublasCreate(&handles[new_device_id])); // Finally, return the handle for the current device return handles[new_device_id]; } private: std::vector<cublasHandle_t> handles; }; static cublasHandle_t context() { thread_local cublas_context c; return c.get_handle(); } // ----------------------------------------------------------------------------------- void gemm ( float beta, tensor& dest, float alpha, const tensor& lhs, bool trans_lhs, const tensor& rhs, bool trans_rhs ) { // Recall that BLAS uses column major order so to deal with that we flip the // order of the lhs and rhs arguments. const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N; const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N; const int dest_nr = dest.num_samples(); const int dest_nc = dest.size()/dest_nr; const int lhs_nr = lhs.num_samples(); const int lhs_nc = lhs.size()/lhs_nr; const int rhs_nr = rhs.num_samples(); const int rhs_nc = rhs.size()/rhs_nr; if (trans_lhs && trans_rhs) { DLIB_ASSERT( dest_nr == lhs_nc && dest_nc == rhs_nr && lhs_nr == rhs_nc) } else if (!trans_lhs && trans_rhs) { DLIB_ASSERT( dest_nr == lhs_nr && dest_nc == rhs_nr && lhs_nc == rhs_nc) } else if (trans_lhs && !trans_rhs) { DLIB_ASSERT( dest_nr == lhs_nc && dest_nc == rhs_nc && lhs_nr == rhs_nr) } else { DLIB_ASSERT( dest_nr == lhs_nr && dest_nc == rhs_nc && lhs_nc == rhs_nr) } const int k = trans_rhs ? rhs_nc : rhs_nr; CHECK_CUBLAS(cublasSgemm(context(), transb, transa, dest_nc, dest_nr, k, &alpha, rhs.device(), rhs_nc, lhs.device(), lhs_nc, &beta, dest.device(),dest_nc)); } // ------------------------------------------------------------------------------------ } } #endif // DLIB_USE_CUDA #endif // DLIB_DNN_CuBLAS_CPP_