|
|
|
|
|
#ifndef DLIB_DNN_CuDA_DATA_PTR_CPP_ |
|
#define DLIB_DNN_CuDA_DATA_PTR_CPP_ |
|
|
|
#ifdef DLIB_USE_CUDA |
|
|
|
#include "cuda_data_ptr.h" |
|
#include "cuda_utils.h" |
|
|
|
namespace dlib |
|
{ |
|
namespace cuda |
|
{ |
|
|
|
|
|
|
|
// Builds a weak handle that mirrors ptr: the byte count (num) is copied and
// pdata is initialized from ptr.pdata.  lock() later tries to promote pdata
// back into a cuda_data_void_ptr.
weak_cuda_data_void_ptr::weak_cuda_data_void_ptr(const cuda_data_void_ptr& ptr)
    : num(ptr.num),
      pdata(ptr.pdata)
{
    // Both members are fully set up by the initializer list.
}
|
|
|
|
|
|
|
// Attempts to promote this weak handle back into a cuda_data_void_ptr.
//
// Returns a cuda_data_void_ptr carrying the locked pdata and the recorded
// size (num) when pdata.lock() yields a non-null pointer; otherwise returns
// a default constructed cuda_data_void_ptr.
cuda_data_void_ptr weak_cuda_data_void_ptr::lock() const
{
    const auto owner = pdata.lock();

    // Nothing to promote: hand back an empty pointer object.
    if (!owner)
        return cuda_data_void_ptr();

    cuda_data_void_ptr strong;
    strong.pdata = owner;
    strong.num = num;
    return strong;
}
|
|
|
|
|
|
|
|
|
// Allocates n bytes with cudaMalloc and stores the result in pdata together
// with a deleter that releases it through cudaFree.
//
// n == 0 performs no allocation at all; only num is recorded and pdata is
// left as it was default initialized.  The result of cudaMalloc is passed
// through CHECK_CUDA (defined outside this file).
cuda_data_void_ptr::cuda_data_void_ptr(size_t n)
    : num(n)
{
    if (n != 0)
    {
        void* data = nullptr;
        CHECK_CUDA(cudaMalloc(&data, n));

        // The deleter does not throw: a failing cudaFree is only reported
        // on std::cerr.
        const auto release = [](void* device_mem)
        {
            const auto status = cudaFree(device_mem);
            if (status != cudaSuccess)
                std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(status) << std::endl;
        };

        pdata.reset(data, release);
    }
}
|
|
|
|
|
|
|
void memcpy( |
|
void* dest, |
|
const cuda_data_void_ptr& src, |
|
const size_t num |
|
) |
|
{ |
|
DLIB_ASSERT(num <= src.size()); |
|
if (src.size() != 0) |
|
{ |
|
CHECK_CUDA(cudaMemcpy(dest, src.data(), num, cudaMemcpyDefault)); |
|
} |
|
} |
|
|
|
|
|
|
|
void memcpy( |
|
void* dest, |
|
const cuda_data_void_ptr& src |
|
) |
|
{ |
|
memcpy(dest, src, src.size()); |
|
} |
|
|
|
|
|
|
|
// Copies num bytes starting at src into the buffer held by dest.
//
//   dest - destination buffer; DLIB_ASSERT checks that num <= dest.size()
//   src  - source address handed straight to cudaMemcpy
//   num  - number of bytes to copy
//
// cudaMemcpyDefault is passed as the copy kind instead of naming a
// direction.  No CUDA call is made when dest is empty (dest.size() == 0).
void memcpy(cuda_data_void_ptr dest, const void* src, const size_t num)
{
    DLIB_ASSERT(num <= dest.size());

    // An empty destination cannot receive anything.
    if (dest.size() == 0)
        return;

    CHECK_CUDA(cudaMemcpy(dest.data(), src, num, cudaMemcpyDefault));
}
|
|
|
|
|
|
|
// Fills all of dest (dest.size() bytes) from src by forwarding to the three
// argument memcpy overload.
void memcpy(cuda_data_void_ptr dest, const void* src)
{
    const auto whole_buffer = dest.size();
    memcpy(dest, src, whole_buffer);
}
|
|
|
|
|
|
|
class cudnn_device_buffer |
|
{ |
|
public: |
|
|
|
cudnn_device_buffer(const cudnn_device_buffer&) = delete; |
|
cudnn_device_buffer& operator=(const cudnn_device_buffer&) = delete; |
|
|
|
cudnn_device_buffer() |
|
{ |
|
buffers.resize(16); |
|
} |
|
~cudnn_device_buffer() |
|
{ |
|
} |
|
|
|
cuda_data_void_ptr get ( |
|
size_t size |
|
) |
|
{ |
|
int new_device_id; |
|
CHECK_CUDA(cudaGetDevice(&new_device_id)); |
|
|
|
if (new_device_id >= (long)buffers.size()) |
|
buffers.resize(new_device_id+16); |
|
|
|
|
|
|
|
cuda_data_void_ptr buff = buffers[new_device_id].lock(); |
|
if (!buff || buff.size() < size) |
|
{ |
|
buff = cuda_data_void_ptr(size); |
|
buffers[new_device_id] = buff; |
|
} |
|
|
|
|
|
return buff; |
|
} |
|
|
|
private: |
|
|
|
std::vector<weak_cuda_data_void_ptr> buffers; |
|
}; |
|
|
|
|
|
|
|
// Returns a scratch buffer of at least size bytes for the calling thread's
// current CUDA device.
//
// The backing cudnn_device_buffer is thread_local, so every thread works
// with its own table of buffers.
cuda_data_void_ptr device_global_buffer(size_t size)
{
    thread_local cudnn_device_buffer per_thread_buffers;
    return per_thread_buffers.get(size);
}
|
|
|
|
|
|
|
} |
|
} |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|