File size: 3,488 Bytes
9375c9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
// Copyright (C) 2015 Davis E. King ([email protected])
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuRAND_CPP_
#define DLIB_DNN_CuRAND_CPP_
#ifdef DLIB_USE_CUDA
#include "curand_dlibapi.h"
#include <sstream>
#include <curand.h>
#include "../string.h"
// Map a cuRAND status code to a short human-readable description.  The result
// is embedded in the message of the dlib::curand_error thrown by CHECK_CURAND.
// Status codes not listed here (e.g. ones introduced by newer cuRAND releases)
// fall back to a generic message.
static const char* curand_get_error_string(curandStatus_t s)
{
    switch(s)
    {
        case CURAND_STATUS_VERSION_MISMATCH:
            return "The cuRAND header file and linked library version do not match.";
        case CURAND_STATUS_NOT_INITIALIZED:
            // NOTE(review): cuRAND documents this code as "generator not
            // initialized"; the original wording is kept unchanged.
            return "CUDA Runtime API initialization failed.";
        case CURAND_STATUS_ALLOCATION_FAILED:
            return "cuRAND memory allocation failed.";
        case CURAND_STATUS_TYPE_ERROR:
            return "The cuRAND generator is of the wrong type.";
        case CURAND_STATUS_OUT_OF_RANGE:
            return "An argument to cuRAND was out of range.";
        case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
            return "The requested length must be a multiple of two.";
        case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
            return "The GPU does not support the double precision required by this generator.";
        case CURAND_STATUS_LAUNCH_FAILURE:
            return "A cuRAND kernel launch failed.";
        case CURAND_STATUS_PREEXISTING_FAILURE:
            return "A preexisting CUDA failure was detected on entry to cuRAND.";
        case CURAND_STATUS_INITIALIZATION_FAILED:
            return "Initialization of CUDA for cuRAND failed.";
        case CURAND_STATUS_ARCH_MISMATCH:
            return "The GPU architecture does not support the requested cuRAND feature.";
        case CURAND_STATUS_INTERNAL_ERROR:
            return "An internal cuRAND library error occurred.";
        default:
            return "A call to cuRAND failed";
    }
}
// Check the return value of a call into the cuRAND library for an error condition.
// If `call` does not return CURAND_STATUS_SUCCESS, a dlib::curand_error is thrown.
// Its message records the call text, the source location, the numeric status
// code and the description from curand_get_error_string().
//
// The macro-local variables carry a dlib_curand_ prefix so they cannot capture
// an identically named variable used inside `call` (with a plain name such as
// `error`, CHECK_CURAND(f(error)) would initialize the local from itself).
#define CHECK_CURAND(call)                                                              \
do{                                                                                     \
    const curandStatus_t dlib_curand_status_ = (call);                                  \
    if (dlib_curand_status_ != CURAND_STATUS_SUCCESS)                                   \
    {                                                                                   \
        std::ostringstream dlib_curand_sout_;                                           \
        dlib_curand_sout_ << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
        dlib_curand_sout_ << "code: " << dlib_curand_status_ << ", reason: " << curand_get_error_string(dlib_curand_status_);\
        throw dlib::curand_error(dlib_curand_sout_.str());                              \
    }                                                                                   \
}while(false)
namespace dlib
{
    namespace cuda
    {

    // ----------------------------------------------------------------------------------------

        // Create a cuRAND pseudo-random generator (CURAND_RNG_PSEUDO_DEFAULT) and
        // seed it with `seed`.  Throws dlib::curand_error if either cuRAND call fails.
        // If seeding fails, the freshly created generator is destroyed before the
        // exception propagates.  A throwing constructor never runs
        // ~curand_generator(), so the generator would otherwise leak.
        curand_generator::
        curand_generator(
            unsigned long long seed
        ) : handle(nullptr)
        {
            curandGenerator_t gen = nullptr;
            CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));

            try
            {
                CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, seed));
            }
            catch (...)
            {
                // The destructor will not run for this partially constructed
                // object, so release the generator here before rethrowing.
                curandDestroyGenerator(gen);
                throw;
            }

            // Publish the handle only once the generator is fully set up.
            handle = gen;
        }

        // Release the cuRAND generator, if one was created.  The status returned
        // by curandDestroyGenerator() is deliberately ignored: destructors must
        // not throw.
        curand_generator::
        ~curand_generator()
        {
            if (handle)
            {
                curandDestroyGenerator(static_cast<curandGenerator_t>(handle));
            }
        }

        // Fill the device memory of `data` with normally distributed floats having
        // the given mean and standard deviation.  Empty tensors are a no-op.
        // Per the cuRAND documentation, pseudo-random generators produce normal
        // values in pairs.  An odd data.size() is therefore rejected by cuRAND
        // with CURAND_STATUS_LENGTH_NOT_MULTIPLE, surfaced as dlib::curand_error.
        void curand_generator::
        fill_gaussian (
            tensor& data,
            float mean,
            float stddev
        )
        {
            if (data.size() == 0)
                return;

            CHECK_CURAND(curandGenerateNormal(static_cast<curandGenerator_t>(handle),
                                              data.device(),
                                              data.size(),
                                              mean,
                                              stddev));
        }

        // Fill the device memory of `data` with uniformly distributed floats
        // (per the cuRAND documentation, in the range (0, 1]).  Empty tensors are
        // a no-op.  Throws dlib::curand_error on failure.
        void curand_generator::
        fill_uniform (
            tensor& data
        )
        {
            if (data.size() == 0)
                return;

            CHECK_CURAND(curandGenerateUniform(static_cast<curandGenerator_t>(handle), data.device(), data.size()));
        }

        // Fill `data` with raw 32-bit unsigned random integers produced by
        // curandGenerate().  Empty buffers are a no-op.  Throws dlib::curand_error
        // on failure.  `data` is passed straight to curandGenerate, so it is
        // presumably convertible to a device `unsigned int*` — TODO confirm
        // against cuda_data_ptr.
        void curand_generator::
        fill (
            cuda_data_ptr<unsigned int>& data
        )
        {
            if (data.size() == 0)
                return;

            CHECK_CURAND(curandGenerate(static_cast<curandGenerator_t>(handle), data, data.size()));
        }

    // -----------------------------------------------------------------------------------

    }
}
#endif // DLIB_USE_CUDA
#endif // DLIB_DNN_CuRAND_CPP_
|