Aging_MouthReplace / dlibs /dlib /cuda /cuda_data_ptr.cpp
AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
raw
history blame
5.03 kB
// Copyright (C) 2017  Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuDA_DATA_PTR_CPP_
#define DLIB_DNN_CuDA_DATA_PTR_CPP_
#ifdef DLIB_USE_CUDA
#include "cuda_data_ptr.h"
#include "cuda_utils.h"
namespace dlib
{
namespace cuda
{
// ----------------------------------------------------------------------------------------
// Builds a weak handle from an existing cuda_data_void_ptr: the byte count
// (num) is copied and pdata is initialized from the source's pdata.
weak_cuda_data_void_ptr::weak_cuda_data_void_ptr(const cuda_data_void_ptr& ptr)
    : num(ptr.num),
      pdata(ptr.pdata)
{}
// ----------------------------------------------------------------------------------------
// Attempts to turn this weak handle back into a usable cuda_data_void_ptr.
//
// When pdata.lock() yields a non-null pointer, the returned handle refers to
// that same storage and carries the recorded size.  Otherwise a
// default-constructed cuda_data_void_ptr is returned.
cuda_data_void_ptr weak_cuda_data_void_ptr::lock() const
{
    cuda_data_void_ptr result;
    if (auto owner = pdata.lock())
    {
        // The storage is still around: hand out a handle that shares it.
        result.pdata = owner;
        result.num = num;
    }
    return result;
}
// -----------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------
// Allocates n bytes with cudaMalloc() and stores them in pdata together with a
// deleter that releases them through cudaFree().
//
// A request for 0 bytes performs no allocation and leaves pdata untouched.
// A cudaFree() failure inside the deleter is only reported on std::cerr.
cuda_data_void_ptr::cuda_data_void_ptr(size_t n)
    : num(n)
{
    if (n != 0)
    {
        void* data = nullptr;
        CHECK_CUDA(cudaMalloc(&data, n));

        // Deleter run by pdata when the allocation is released.
        const auto release = [](void* mem)
        {
            const auto status = cudaFree(mem);
            if (status != cudaSuccess)
                std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(status) << std::endl;
        };
        pdata.reset(data, release);
    }
}
// ------------------------------------------------------------------------------------
// Copies the first num bytes of the buffer held by src into dest.
//
// requires (checked by DLIB_ASSERT): num <= src.size()
// Nothing is copied when src is an empty (size 0) buffer.  The copy uses
// cudaMemcpyDefault, so no explicit transfer direction is given here.
void memcpy(void* dest, const cuda_data_void_ptr& src, const size_t num)
{
    DLIB_ASSERT(num <= src.size());

    // An empty source has nothing to read from.
    if (src.size() == 0)
        return;

    CHECK_CUDA(cudaMemcpy(dest, src.data(), num, cudaMemcpyDefault));
}
// ------------------------------------------------------------------------------------
void memcpy(
void* dest,
const cuda_data_void_ptr& src
)
{
memcpy(dest, src, src.size());
}
// ------------------------------------------------------------------------------------
// Copies num bytes starting at src into the buffer held by dest.
//
// requires (checked by DLIB_ASSERT): num <= dest.size()
// Nothing is copied when dest is an empty (size 0) buffer.  The copy uses
// cudaMemcpyDefault, so no explicit transfer direction is given here.
void memcpy(cuda_data_void_ptr dest, const void* src, const size_t num)
{
    DLIB_ASSERT(num <= dest.size());

    // An empty destination has nowhere to write to.
    if (dest.size() == 0)
        return;

    CHECK_CUDA(cudaMemcpy(dest.data(), src, num, cudaMemcpyDefault));
}
// ------------------------------------------------------------------------------------
// Fills the whole of dest (all dest.size() bytes) from src by forwarding to
// the three-argument overload.
void memcpy(cuda_data_void_ptr dest, const void* src)
{
    const size_t whole_buffer = dest.size();
    memcpy(dest, src, whole_buffer);
}
// ------------------------------------------------------------------------------------
// A table of scratch buffers with one slot per CUDA device id, as reported by
// cudaGetDevice().  Each slot holds a weak_cuda_data_void_ptr, so a slot is
// only reusable while lock() on it still yields a non-empty handle.
class cudnn_device_buffer
{
public:
    // Start out with room for 16 device slots.
    cudnn_device_buffer() : buffers(16) {}

    ~cudnn_device_buffer() {}

    // not copyable
    cudnn_device_buffer(const cudnn_device_buffer&) = delete;
    cudnn_device_buffer& operator=(const cudnn_device_buffer&) = delete;

    // Returns a buffer of at least size bytes for the currently selected
    // device.  The buffer remembered for that device is reused when it is
    // still available and large enough; otherwise a new one is allocated and
    // remembered in its place.
    cuda_data_void_ptr get(size_t size)
    {
        int new_device_id;
        CHECK_CUDA(cudaGetDevice(&new_device_id));

        // Grow the table, with some headroom, if this device id has no slot yet.
        if (static_cast<long>(buffers.size()) <= new_device_id)
            buffers.resize(new_device_id+16);

        weak_cuda_data_void_ptr& slot = buffers[new_device_id];
        cuda_data_void_ptr buff = slot.lock();

        // Allocate when nothing usable is remembered, or when what is
        // remembered is too small for this request.
        const bool must_allocate = !buff || buff.size() < size;
        if (must_allocate)
        {
            buff = cuda_data_void_ptr(size);
            slot = buff;
        }

        return buff;
    }

private:
    // Indexed by device id.
    std::vector<weak_cuda_data_void_ptr> buffers;
};
// ----------------------------------------------------------------------------------------
// Returns a buffer of at least size bytes for the current device, obtained
// from a cudnn_device_buffer that is thread_local, i.e. each calling thread
// uses its own instance.
cuda_data_void_ptr device_global_buffer(size_t size)
{
    thread_local cudnn_device_buffer per_thread_buffers;
    cuda_data_void_ptr result = per_thread_buffers.get(size);
    return result;
}
// ------------------------------------------------------------------------------------
}
}
#endif // DLIB_USE_CUDA
#endif // DLIB_DNN_CuDA_DATA_PTR_CPP_