|
<html><head><title>dlib C++ Library - cudnn_dlibapi.cpp</title></head><body bgcolor='white'><pre> |
|
<font color='#009900'>// Copyright (C) 2015 Davis E. King ([email protected]) |
|
</font><font color='#009900'>// License: Boost Software License See LICENSE.txt for the full license. |
|
</font><font color='#0000FF'>#ifndef</font> DLIB_DNN_CuDNN_CPP_ |
|
<font color='#0000FF'>#define</font> DLIB_DNN_CuDNN_CPP_ |
|
|
|
<font color='#0000FF'>#ifdef</font> DLIB_USE_CUDA |
|
|
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='cudnn_dlibapi.h.html'>cudnn_dlibapi.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='tensor.h.html'>tensor.h</a>" |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>cudnn.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>tuple<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>map<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>string<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>vector<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='cuda_utils.h.html'>cuda_utils.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='cpu_dlib.h.html'>cpu_dlib.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='cuda_dlib.h.html'>cuda_dlib.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='tensor_tools.h.html'>tensor_tools.h</a>" |
|
|
|
<font color='#0000FF'>static</font> <font color='#0000FF'>const</font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font> <b><a name='cudnn_get_error_string'></a>cudnn_get_error_string</b><font face='Lucida Console'>(</font>cudnnStatus_t s<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>switch</font><font face='Lucida Console'>(</font>s<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>case</font> CUDNN_STATUS_NOT_INITIALIZED: |
|
<font color='#0000FF'>return</font> "<font color='#CC0000'>CUDA Runtime API initialization failed.</font>"; |
|
<font color='#0000FF'>case</font> CUDNN_STATUS_ALLOC_FAILED: |
|
<font color='#0000FF'>return</font> "<font color='#CC0000'>CUDA Resources could not be allocated.</font>"; |
|
<font color='#0000FF'>case</font> CUDNN_STATUS_BAD_PARAM: |
|
<font color='#0000FF'>return</font> "<font color='#CC0000'>CUDNN_STATUS_BAD_PARAM</font>"; |
|
<font color='#0000FF'>case</font> CUDNN_STATUS_EXECUTION_FAILED: |
|
<font color='#0000FF'>return</font> "<font color='#CC0000'>CUDNN_STATUS_EXECUTION_FAILED</font>"; |
|
<font color='#0000FF'>case</font> CUDNN_STATUS_NOT_SUPPORTED: |
|
<font color='#0000FF'>return</font> "<font color='#CC0000'>CUDNN_STATUS_NOT_SUPPORTED</font>"; |
|
<font color='#0000FF'>case</font> CUDNN_STATUS_ARCH_MISMATCH: |
|
<font color='#0000FF'>return</font> "<font color='#CC0000'>CUDNN_STATUS_ARCH_MISMATCH: Your GPU is too old and not supported by cuDNN</font>"; |
|
<font color='#0000FF'>default</font>: |
|
<font color='#0000FF'>return</font> "<font color='#CC0000'>A call to cuDNN failed</font>"; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
// CHECK_CUDNN(call): evaluate a cuDNN API call exactly once and, unless it
// returned CUDNN_STATUS_SUCCESS, throw dlib::cudnn_error.  The exception text
// names the call expression, the file and line of the check, the numeric
// status code, and the reason from cudnn_get_error_string().  The body is a
// do/while(false) so the macro acts as a single statement.
#define CHECK_CUDNN(call)                                                       \
do {                                                                            \
    const cudnnStatus_t cudnn_status_ = call;                                   \
    if (cudnn_status_ == CUDNN_STATUS_SUCCESS)                                  \
        break;                                                                  \
    std::ostringstream cudnn_msg_;                                              \
    cudnn_msg_ << "Error while calling " << #call                               \
               << " in file " << __FILE__ << ":" << __LINE__ << ". "            \
               << "code: " << cudnn_status_                                     \
               << ", reason: " << cudnn_get_error_string(cudnn_status_);        \
    throw dlib::cudnn_error(cudnn_msg_.str());                                  \
} while(false)
|
|
|
|
|
<font color='#0000FF'>namespace</font> dlib |
|
<b>{</b> |
|
|
|
<font color='#0000FF'>namespace</font> cuda |
|
<b>{</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#0000FF'>static</font> cudnnTensorDescriptor_t <b><a name='descriptor'></a>descriptor</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> t<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>return</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>t.<font color='#BB00BB'>get_cudnn_tensor_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>static</font> cudnnTensorDescriptor_t <b><a name='descriptor'></a>descriptor</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> tensor_descriptor<font color='#5555FF'>&</font> t<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>return</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>t.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#0000FF'>class</font> <b><a name='cudnn_context'></a>cudnn_context</b> |
|
<b>{</b> |
|
<font color='#0000FF'>public</font>: |
|
<font color='#009900'>// not copyable |
|
</font> <b><a name='cudnn_context'></a>cudnn_context</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnn_context<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>; |
|
cudnn_context<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>=</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnn_context<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>; |
|
|
|
<b><a name='cudnn_context'></a>cudnn_context</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
handles.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font><font color='#979000'>16</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
~<b><a name='cudnn_context'></a>cudnn_context</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font> h : handles<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>h<font face='Lucida Console'>)</font> |
|
<font color='#BB00BB'>cudnnDestroy</font><font face='Lucida Console'>(</font>h<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
cudnnHandle_t <b><a name='get_handle'></a>get_handle</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'><u>int</u></font> new_device_id; |
|
<font color='#BB00BB'>CHECK_CUDA</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudaGetDevice</font><font face='Lucida Console'>(</font><font color='#5555FF'>&</font>new_device_id<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// make room for more devices if needed |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>new_device_id <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>handles.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
handles.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>new_device_id<font color='#5555FF'>+</font><font color='#979000'>16</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// If we don't have a handle already for this device then make one |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>handles[new_device_id]<font face='Lucida Console'>)</font> |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreate</font><font face='Lucida Console'>(</font><font color='#5555FF'>&</font>handles[new_device_id]<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Finally, return the handle for the current device |
|
</font> <font color='#0000FF'>return</font> handles[new_device_id]; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>private</font>: |
|
|
|
std::vector<font color='#5555FF'><</font>cudnnHandle_t<font color='#5555FF'>></font> handles; |
|
<b>}</b>; |
|
|
|
<font color='#0000FF'>static</font> cudnnHandle_t <b><a name='context'></a>context</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
thread_local cudnn_context c; |
|
<font color='#0000FF'>return</font> c.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#0000FF'>class</font> <b><a name='cudnn_activation_descriptor'></a>cudnn_activation_descriptor</b> |
|
<b>{</b> |
|
<font color='#0000FF'>public</font>: |
|
<font color='#009900'>// not copyable |
|
</font> <b><a name='cudnn_activation_descriptor'></a>cudnn_activation_descriptor</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnn_activation_descriptor<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>; |
|
cudnn_activation_descriptor<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>=</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnn_activation_descriptor<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>; |
|
|
|
<b><a name='cudnn_activation_descriptor'></a>cudnn_activation_descriptor</b><font face='Lucida Console'>(</font> |
|
cudnnActivationMode_t mode, |
|
cudnnNanPropagation_t reluNanOpt, |
|
<font color='#0000FF'><u>double</u></font> reluCeiling |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreateActivationDescriptor</font><font face='Lucida Console'>(</font><font color='#5555FF'>&</font>handle<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetActivationDescriptor</font><font face='Lucida Console'>(</font>handle, mode, reluNanOpt, reluCeiling<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
~<b><a name='cudnn_activation_descriptor'></a>cudnn_activation_descriptor</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>cudnnDestroyActivationDescriptor</font><font face='Lucida Console'>(</font>handle<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
cudnnActivationDescriptor_t <b><a name='get_handle'></a>get_handle</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>return</font> handle; |
|
<b>}</b> |
|
<font color='#0000FF'>private</font>: |
|
cudnnActivationDescriptor_t handle; |
|
<b>}</b>; |
|
|
|
<font color='#0000FF'>static</font> cudnnActivationDescriptor_t <b><a name='relu_activation_descriptor'></a>relu_activation_descriptor</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
thread_local cudnn_activation_descriptor <font color='#BB00BB'>des</font><font face='Lucida Console'>(</font>CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN,<font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font> des.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>static</font> cudnnActivationDescriptor_t <b><a name='sigmoid_activation_descriptor'></a>sigmoid_activation_descriptor</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
thread_local cudnn_activation_descriptor <font color='#BB00BB'>des</font><font face='Lucida Console'>(</font>CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN,<font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font> des.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>static</font> cudnnActivationDescriptor_t <b><a name='tanh_activation_descriptor'></a>tanh_activation_descriptor</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
thread_local cudnn_activation_descriptor <font color='#BB00BB'>des</font><font face='Lucida Console'>(</font>CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN,<font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font> des.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
tensor_descriptor:: |
|
<b><a name='tensor_descriptor'></a>tensor_descriptor</b><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> : handle<font face='Lucida Console'>(</font>nullptr<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<b>}</b> |
|
|
|
// Release the cuDNN descriptor, if any.  Resizing to all-zero dimensions
// destroys the current handle and leaves it null.
tensor_descriptor::
~tensor_descriptor()
{
    this->set_size(0, 0, 0, 0);
}
|
|
|
<font color='#0000FF'><u>void</u></font> tensor_descriptor:: |
|
<b><a name='set_size'></a>set_size</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>int</u></font> n, |
|
<font color='#0000FF'><u>int</u></font> k, |
|
<font color='#0000FF'><u>int</u></font> nr, |
|
<font color='#0000FF'><u>int</u></font> nc |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>handle<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>cudnnDestroyTensorDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>handle<font face='Lucida Console'>)</font>; |
|
handle <font color='#5555FF'>=</font> nullptr; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>n <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> nr <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> nc <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> k <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
cudnnTensorDescriptor_t h; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreateTensorDescriptor</font><font face='Lucida Console'>(</font><font color='#5555FF'>&</font>h<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
handle <font color='#5555FF'>=</font> h; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetTensor4dDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>handle, |
|
CUDNN_TENSOR_NCHW, |
|
CUDNN_DATA_FLOAT, |
|
n, |
|
k, |
|
nr, |
|
nc<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> tensor_descriptor:: |
|
<b><a name='get_size'></a>get_size</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>int</u></font><font color='#5555FF'>&</font> n, |
|
<font color='#0000FF'><u>int</u></font><font color='#5555FF'>&</font> k, |
|
<font color='#0000FF'><u>int</u></font><font color='#5555FF'>&</font> nr, |
|
<font color='#0000FF'><u>int</u></font><font color='#5555FF'>&</font> nc |
|
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>handle<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'><u>int</u></font> nStride, cStride, hStride, wStride; |
|
cudnnDataType_t datatype; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetTensor4dDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>handle, |
|
<font color='#5555FF'>&</font>datatype, |
|
<font color='#5555FF'>&</font>n, |
|
<font color='#5555FF'>&</font>k, |
|
<font color='#5555FF'>&</font>nr, |
|
<font color='#5555FF'>&</font>nc, |
|
<font color='#5555FF'>&</font>nStride, |
|
<font color='#5555FF'>&</font>cStride, |
|
<font color='#5555FF'>&</font>hStride, |
|
<font color='#5555FF'>&</font>wStride<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
n <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
k <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
nr <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
nc <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='add'></a>add</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>float</u></font> beta, |
|
tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'><u>float</u></font> alpha, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>src, dest<font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> |
|
<font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font><font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> |
|
<font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> |
|
<font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> |
|
<font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>src,dest<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>false</font> , |
|
"<font color='#CC0000'>\n\t dest.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> dest.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>\n\t dest.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> dest.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>\n\t dest.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> dest.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>\n\t dest.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> dest.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>\n\t src.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>\n\t src.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>\n\t src.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>\n\t src.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> beta <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Call the dlib function in this case since it's faster than the one that |
|
</font> <font color='#009900'>// comes with cuDNN (at least as of cuDNN v4). |
|
</font> <font color='#BB00BB'>add_scaled</font><font face='Lucida Console'>(</font>dest, alpha, src<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>add_cv_to_all_columns</font><font face='Lucida Console'>(</font>beta, dest, alpha, src<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font>; |
|
<b>}</b> |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnAddTensor</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
// assign_conv_bias_gradient: overwrites grad with the bias gradient that |
|
// corresponds to the convolution output gradient gradient_input. |
|
// Preconditions (checked below): grad is 1 x k x 1 x 1 with k at least 1, |
|
// gradient_input.k() equals grad.k(), gradient_input is non-empty, and grad |
|
// and gradient_input are different objects. |
|
<font color='#0000FF'><u>void</u></font> <b><a name='assign_conv_bias_gradient'></a>assign_conv_bias_gradient</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> grad, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
grad.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
grad.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
grad.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
grad.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gradient_input.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> grad.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gradient_input.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>false</font> |
|
<font face='Lucida Console'>)</font>; |
|
|
|
// cuDNN combines its result with the prior contents of the output as |
|
//   grad = alpha*bias_gradient + beta*grad |
|
// so beta = 0 discards whatever grad held, making this an assignment. Per the |
|
// cuDNN docs the bias gradient sums gradient_input over every sample and |
|
// spatial position, leaving one value per channel. |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnConvolutionBackwardBias</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>, |
|
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
// batch_normalize_inference: normalizes src with the stored running_means and |
|
// running_variances (not statistics of the current batch) and writes the |
|
// result to dest, which is resized to match src. |
|
// Normalization is per activation (CUDNN_BATCHNORM_PER_ACTIVATION): gamma, |
|
// beta, running_means and running_variances must all be |
|
// 1 x src.k() x src.nr() x src.nc(), so every (k, r, c) position has its own |
|
// scale, shift, mean and variance. Per the cuDNN docs the output is |
|
//   dest = gamma*(src - running_mean)/sqrt(running_variance + eps) + beta |
|
<font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize_inference'></a>batch_normalize_inference</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps, |
|
resizable_tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gamma, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> beta, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> running_means, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> running_variances |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, beta<font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, running_means<font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, running_variances<font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
eps <font color='#5555FF'>></font> <font color='#979000'>0</font>, |
|
"<font color='#CC0000'>\ngamma.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_means.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_means.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_means.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_means.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_means.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_means.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_means.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_means.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_variances.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_variances.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_variances.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_variances.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_variances.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_variances.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_variances.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_variances.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\neps: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> eps |
|
<font face='Lucida Console'>)</font>; |
|
// cuDNN blends the output as in_scale*normalized + out_scale*dest; with |
|
// out_scale = 0 the previous contents of dest are discarded. |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
dest.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>; |
|
|
|
// NOTE(review): only eps greater than 0 is asserted above. cuDNN also |
|
// documents a minimum epsilon (CUDNN_BN_MIN_EPSILON, whose value depends on |
|
// the cuDNN version) and rejects smaller values; that failure would |
|
// presumably be reported through CHECK_CUDNN. |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationForwardInference</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
CUDNN_BATCHNORM_PER_ACTIVATION, |
|
<font color='#5555FF'>&</font>in_scale, |
|
<font color='#5555FF'>&</font>out_scale, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>, |
|
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
beta.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
running_means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
running_variances.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
eps<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize'></a>batch_normalize</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps, |
|
resizable_tensor<font color='#5555FF'>&</font> dest, |
|
resizable_tensor<font color='#5555FF'>&</font> means, |
|
resizable_tensor<font color='#5555FF'>&</font> invstds, |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> averaging_factor, |
|
resizable_tensor<font color='#5555FF'>&</font> running_means, |
|
resizable_tensor<font color='#5555FF'>&</font> running_variances, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gamma, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> beta |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#979000'>0</font> <font color='#5555FF'><</font><font color='#5555FF'>=</font> averaging_factor <font color='#5555FF'>&</font><font color='#5555FF'>&</font> averaging_factor <font color='#5555FF'><</font><font color='#5555FF'>=</font> <font color='#979000'>1</font>, "<font color='#CC0000'>averaging_factor: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> averaging_factor<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>averaging_factor<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>running_means,means<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>averaging_factor<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>running_variances,invstds<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
eps <font color='#5555FF'>></font> <font color='#979000'>0</font>, |
|
"<font color='#CC0000'>\ngamma.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\neps: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> eps |
|
<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
dest.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>; |
|
means.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>1</font>, src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
invstds.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>; |
|
running_means.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>; |
|
running_variances.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// cuDNN requires that running_means and running_variances be initialized to |
|
</font> <font color='#009900'>// some valid float values even if the averaging factor would have ignored |
|
</font> <font color='#009900'>// them. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>averaging_factor <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
running_means <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
running_variances <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<b>}</b> |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationForwardTraining</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
CUDNN_BATCHNORM_PER_ACTIVATION, |
|
<font color='#5555FF'>&</font>in_scale, |
|
<font color='#5555FF'>&</font>out_scale, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>, |
|
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
beta.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
averaging_factor, |
|
running_means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
running_variances.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
eps, |
|
means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
invstds.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize_gradient'></a>batch_normalize_gradient</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> means, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> invstds, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gamma, |
|
tensor<font color='#5555FF'>&</font> src_grad, |
|
tensor<font color='#5555FF'>&</font> gamma_grad, |
|
tensor<font color='#5555FF'>&</font> beta_grad |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> num <font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font>src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font>src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>means.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>invstds.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>gamma.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>gamma_grad.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>beta_grad.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gradient_input, src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gradient_input, src_grad<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>eps <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale_params <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale_params <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationBackward</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
CUDNN_BATCHNORM_PER_ACTIVATION, |
|
<font color='#5555FF'>&</font>in_scale, |
|
<font color='#5555FF'>&</font>out_scale, |
|
<font color='#5555FF'>&</font>in_scale_params, |
|
<font color='#5555FF'>&</font>out_scale_params, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src_grad<font face='Lucida Console'>)</font>, |
|
src_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>, |
|
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
gamma_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
beta_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
eps, |
|
means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
invstds.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize_conv_inference'></a>batch_normalize_conv_inference</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps, |
|
resizable_tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gamma, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> beta, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> running_means, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> running_variances |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, beta<font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, running_means<font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, running_variances<font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
eps <font color='#5555FF'>></font> <font color='#979000'>0</font>, |
|
"<font color='#CC0000'>\ngamma.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_means.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_means.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_means.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_means.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_means.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_means.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_means.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_means.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_variances.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_variances.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_variances.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_variances.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_variances.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_variances.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nrunning_variances.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> running_variances.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\neps: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> eps |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
dest.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationForwardInference</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
CUDNN_BATCHNORM_SPATIAL, |
|
<font color='#5555FF'>&</font>in_scale, |
|
<font color='#5555FF'>&</font>out_scale, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>, |
|
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
beta.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
running_means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
running_variances.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
eps<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize_conv'></a>batch_normalize_conv</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps, |
|
resizable_tensor<font color='#5555FF'>&</font> dest, |
|
resizable_tensor<font color='#5555FF'>&</font> means, |
|
resizable_tensor<font color='#5555FF'>&</font> invstds, |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> averaging_factor, |
|
resizable_tensor<font color='#5555FF'>&</font> running_means, |
|
resizable_tensor<font color='#5555FF'>&</font> running_variances, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gamma, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> beta |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#979000'>0</font> <font color='#5555FF'><</font><font color='#5555FF'>=</font> averaging_factor <font color='#5555FF'>&</font><font color='#5555FF'>&</font> averaging_factor <font color='#5555FF'><</font><font color='#5555FF'>=</font> <font color='#979000'>1</font>, "<font color='#CC0000'>averaging_factor: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> averaging_factor<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>averaging_factor<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>running_means,means<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>averaging_factor<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>running_variances,invstds<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
eps <font color='#5555FF'>></font> <font color='#979000'>0</font>, |
|
"<font color='#CC0000'>\ngamma.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\ngamma.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.num_samples(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nbeta.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.k(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\nsrc.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\neps: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> eps |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
dest.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>; |
|
means.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>1</font>, src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
invstds.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>; |
|
running_means.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>; |
|
running_variances.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// cuDNN requires that running_means and running_variances be initialized to |
|
</font> <font color='#009900'>// some valid float values even if the averaging factor would have ignored |
|
</font> <font color='#009900'>// them. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>averaging_factor <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
running_means <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
running_variances <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<b>}</b> |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationForwardTraining</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
CUDNN_BATCHNORM_SPATIAL, |
|
<font color='#5555FF'>&</font>in_scale, |
|
<font color='#5555FF'>&</font>out_scale, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>, |
|
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
beta.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
averaging_factor, |
|
running_means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
running_variances.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
eps, |
|
means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
invstds.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize_conv_gradient'></a>batch_normalize_conv_gradient</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> means, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> invstds, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gamma, |
|
tensor<font color='#5555FF'>&</font> src_grad, |
|
tensor<font color='#5555FF'>&</font> gamma_grad, |
|
tensor<font color='#5555FF'>&</font> beta_grad |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>means.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>invstds.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>gamma.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>gamma_grad.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>beta_grad.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gradient_input, src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gradient_input, src_grad<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>eps <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale_params <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale_params <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationBackward</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
CUDNN_BATCHNORM_SPATIAL, |
|
<font color='#5555FF'>&</font>in_scale, |
|
<font color='#5555FF'>&</font>out_scale, |
|
<font color='#5555FF'>&</font>in_scale_params, |
|
<font color='#5555FF'>&</font>out_scale_params, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src_grad<font face='Lucida Console'>)</font>, |
|
src_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>, |
|
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
gamma_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
beta_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
eps, |
|
means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
invstds.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> <font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#009900'>// Default constructor: starts with no cuDNN descriptors and algorithm id 0 for every
// pass.  The two handles must be null before clear() runs, because clear() only
// destroys handles that are non-null; clear() then zeroes the remaining state.</font>
tensor_conv::
<b><a name='tensor_conv'></a>tensor_conv</b>(
)
{
    filter_handle = nullptr;
    conv_handle = nullptr;
    forward_algo = 0;
    backward_data_algo = 0;
    backward_filters_algo = 0;
    clear();
}
|
|
|
<font color='#009900'>// Returns the object to its empty state: destroys any cuDNN filter/convolution
// descriptors, zeroes every cached dimension, stride, padding, algorithm id and
// workspace size, and releases the workspace buffers.</font>
void tensor_conv::
<b><a name='clear'></a>clear</b> (
)
{
    <font color='#009900'>// Best-effort release: the destroy calls' return codes are intentionally not checked.</font>
    if (filter_handle)
        cudnnDestroyFilterDescriptor((cudnnFilterDescriptor_t)filter_handle);
    filter_handle = nullptr;

    if (conv_handle)
        cudnnDestroyConvolutionDescriptor((cudnnConvolutionDescriptor_t)conv_handle);
    conv_handle = nullptr;

    out_num_samples = out_k = out_nr = out_nc = 0;

    stride_y = stride_x = 0;
    padding_y = padding_x = 0;

    data_num_samples = data_k = data_nr = data_nc = 0;
    filters_num_samples = filters_k = filters_nr = filters_nc = 0;

    forward_algo = backward_data_algo = backward_filters_algo = 0;

    forward_workspace_size_in_bytes = 0;
    backward_data_workspace_size_in_bytes = 0;
    backward_filters_workspace_size_in_bytes = 0;

    forward_workspace.reset();
    backward_data_workspace.reset();
    backward_filters_workspace.reset();
}
|
|
|
<font color='#009900'>// Given an array of cudnn algorithm performance results, like
// cudnnConvolutionFwdAlgoPerf_t, pick the best one to use.
//
// perf_results must be non-empty and its first entry must have succeeded; both
// conditions are checked below.  When dnn_prefer_fastest_algorithms() is true the
// first entry's algorithm is returned.  Otherwise the successful entry with the
// smallest memory requirement is returned, with ties going to the earliest entry.</font>
template &lt;typename T&gt;
<b><a name='decltype'></a>decltype</b>(std::declval&lt;T&gt;().algo) <b><a name='pick_best_algorithm'></a>pick_best_algorithm</b>(const std::vector&lt;T&gt; &amp;perf_results)
{
    DLIB_CASSERT(!perf_results.empty());
    CHECK_CUDNN(perf_results[0].status);
    if (dnn_prefer_fastest_algorithms())
        return perf_results[0].algo;

    <font color='#009900'>// Otherwise we find the algorithm that has a good status and uses the least amount
    // of memory.  The search is seeded with perf_results[0], which is known to have
    // succeeded, so best_alg always holds a valid algorithm even if no entry reports
    // a memory requirement below SIZE_MAX.</font>
    size_t best_memory = perf_results[0].memory;
    decltype(std::declval&lt;T&gt;().algo) best_alg = perf_results[0].algo;
    for (size_t i = 1; i &lt; perf_results.size(); ++i)
    {
        const auto&amp; perf = perf_results[i];
        if (perf.status == CUDNN_STATUS_SUCCESS &amp;&amp; perf.memory &lt; best_memory)
        {
            best_memory = perf.memory;
            best_alg = perf.algo;
        }
    }
    return best_alg;
}
|
|
|
<font color='#009900'>// Chooses the cuDNN forward, backward-data and backward-filter convolution algorithms
// for the current filter/convolution descriptors, the data tensor and the output
// descriptor, and stores them in forward_algo, backward_data_algo and
// backward_filters_algo.  No workspace buffers are allocated here.</font>
void tensor_conv::
<b><a name='select_best_algorithms'></a>select_best_algorithms</b> (
    const tensor&amp; data,
    const tensor_descriptor&amp; dest_desc
)
{
    <font color='#009900'>// Calling the cuDNN "find the best algorithm" functions is really slow.  So we keep a
    // cache that tells us what method was best for a particular configuration.
    // NOTE(review): the key is only (stride, padding, filter size).  It ignores the
    // input size, channel counts and the dnn_prefer_fastest_algorithms() setting, so a
    // cached choice can be reused for shapes it was not benchmarked on -- confirm intended.</font>
    thread_local std::map&lt;std::tuple&lt;int,int,int,int,long,long&gt;,
                          std::tuple&lt;int,int,int&gt;&gt; config_to_algo_cache;

    <font color='#009900'>// If we have already found good algorithms for this setting then just pull them from
    // the cache.</font>
    const auto cache_key = std::make_tuple(stride_y, stride_x, padding_y, padding_x, filters_nr, filters_nc);
    const auto iter = config_to_algo_cache.find(cache_key);
    if (iter != config_to_algo_cache.end())
    {
        std::tie(forward_algo, backward_data_algo, backward_filters_algo) = iter-&gt;second;
        return;
    }


    <font color='#009900'>// Pick which forward algorithm we will use.</font>
    cudnnConvolutionFwdAlgo_t forward_best_algo;
#if CUDNN_MAJOR &gt;= 8
    {
        int num_possible_algorithms = 0;
        CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(context(), &amp;num_possible_algorithms));
        std::vector&lt;cudnnConvolutionFwdAlgoPerf_t&gt; perf_results(num_possible_algorithms);
        int num_algorithms = 0;
        CHECK_CUDNN(cudnnFindConvolutionForwardAlgorithm(
                context(),
                descriptor(data),
                (const cudnnFilterDescriptor_t)filter_handle,
                (const cudnnConvolutionDescriptor_t)conv_handle,
                descriptor(dest_desc),
                num_possible_algorithms,
                &amp;num_algorithms,
                perf_results.data()));
        perf_results.resize(num_algorithms);
        forward_best_algo = pick_best_algorithm(perf_results);
    }
#else
    CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
            context(),
            descriptor(data),
            (const cudnnFilterDescriptor_t)filter_handle,
            (const cudnnConvolutionDescriptor_t)conv_handle,
            descriptor(dest_desc),
            dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
            std::numeric_limits&lt;size_t&gt;::max(),
            &amp;forward_best_algo));
#endif
    forward_algo = forward_best_algo;


    <font color='#009900'>// Pick which backward data algorithm we will use.</font>
    cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
#if CUDNN_MAJOR &gt;= 8
    {
        int num_possible_algorithms = 0;
        <font color='#009900'>// Size the search with the backward-DATA algorithm count (this previously
        // queried the backward-filter count by mistake).</font>
        CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(context(), &amp;num_possible_algorithms));
        std::vector&lt;cudnnConvolutionBwdDataAlgoPerf_t&gt; perf_results(num_possible_algorithms);
        int num_algorithms = 0;
        CHECK_CUDNN(cudnnFindConvolutionBackwardDataAlgorithm(
                context(),
                (const cudnnFilterDescriptor_t)filter_handle,
                descriptor(dest_desc),
                (const cudnnConvolutionDescriptor_t)conv_handle,
                descriptor(data),
                num_possible_algorithms,
                &amp;num_algorithms,
                perf_results.data()));
        perf_results.resize(num_algorithms);
        backward_data_best_algo = pick_best_algorithm(perf_results);
    }
#else
    CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
            context(),
            (const cudnnFilterDescriptor_t)filter_handle,
            descriptor(dest_desc),
            (const cudnnConvolutionDescriptor_t)conv_handle,
            descriptor(data),
            dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE,
            std::numeric_limits&lt;size_t&gt;::max(),
            &amp;backward_data_best_algo));
#endif
    backward_data_algo = backward_data_best_algo;


    <font color='#009900'>// Pick which backward filters algorithm we will use.</font>
    cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
#if CUDNN_MAJOR &gt;= 8
    {
        int num_possible_algorithms = 0;
        CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &amp;num_possible_algorithms));
        std::vector&lt;cudnnConvolutionBwdFilterAlgoPerf_t&gt; perf_results(num_possible_algorithms);
        int num_algorithms = 0;
        CHECK_CUDNN(cudnnFindConvolutionBackwardFilterAlgorithm(
                context(),
                descriptor(data),
                descriptor(dest_desc),
                (const cudnnConvolutionDescriptor_t)conv_handle,
                (const cudnnFilterDescriptor_t)filter_handle,
                num_possible_algorithms,
                &amp;num_algorithms,
                perf_results.data()));
        perf_results.resize(num_algorithms);
        backward_filters_best_algo = pick_best_algorithm(perf_results);
    }
#else
    CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(
            context(),
            descriptor(data),
            descriptor(dest_desc),
            (const cudnnConvolutionDescriptor_t)conv_handle,
            (const cudnnFilterDescriptor_t)filter_handle,
            dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE,
            std::numeric_limits&lt;size_t&gt;::max(),
            &amp;backward_filters_best_algo));
#endif

    <font color='#009900'>// cuDNN 5.1 has a bug that causes cudnnGetConvolutionBackwardFilterAlgorithm() to
    // pick the winograd algorithm even in cases where cuDNN doesn't support it, leading
    // to incorrect outputs.  So when we are in a case where winograd isn't supported we
    // overrule the library's choice and pick a safe algorithm.  Note that this override
    // is applied for every cuDNN version, including the CUDNN_MAJOR &gt;= 8 path.</font>
    if (dnn_prefer_fastest_algorithms() &amp;&amp;
        !(stride_x == 1 &amp;&amp; stride_y == 1 &amp;&amp; ((filters_nr==3&amp;&amp;filters_nc==3) || (filters_nr==5&amp;&amp;filters_nc==5)))
    )
    {
        backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
    }
    backward_filters_algo = backward_filters_best_algo;


    <font color='#009900'>// Save this algorithm selection in the cache</font>
    config_to_algo_cache[cache_key] = std::make_tuple(forward_algo, backward_data_algo, backward_filters_algo);
}
|
|
|
<font color='#0000FF'><u>void</u></font> tensor_conv:: |
|
<b><a name='setup'></a>setup</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> data, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> filters, |
|
<font color='#0000FF'><u>int</u></font> stride_y_, |
|
<font color='#0000FF'><u>int</u></font> stride_x_, |
|
<font color='#0000FF'><u>int</u></font> padding_y_, |
|
<font color='#0000FF'><u>int</u></font> padding_x_ |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// if the last call to setup gave the same exact settings then don't do |
|
</font> <font color='#009900'>// anything. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>data_num_samples <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
data_k <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
data_nr <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
data_nc <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
stride_y_ <font color='#5555FF'>=</font><font color='#5555FF'>=</font> stride_y <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
stride_x_ <font color='#5555FF'>=</font><font color='#5555FF'>=</font> stride_x <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
padding_y_ <font color='#5555FF'>=</font><font color='#5555FF'>=</font> padding_y <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
padding_x_ <font color='#5555FF'>=</font><font color='#5555FF'>=</font> padding_x <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
filters_num_samples <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
filters_k <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
filters_nr <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
filters_nc <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>return</font>; |
|
<b>}</b> |
|
|
|
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>try</font> |
|
<b>{</b> |
|
stride_y <font color='#5555FF'>=</font> stride_y_; |
|
stride_x <font color='#5555FF'>=</font> stride_x_; |
|
padding_y <font color='#5555FF'>=</font> padding_y_; |
|
padding_x <font color='#5555FF'>=</font> padding_x_; |
|
data_num_samples <font color='#5555FF'>=</font> data.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_k <font color='#5555FF'>=</font> data.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_nr <font color='#5555FF'>=</font> data.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_nc <font color='#5555FF'>=</font> data.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
filters_num_samples <font color='#5555FF'>=</font> filters.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
filters_k <font color='#5555FF'>=</font> filters.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
filters_nr <font color='#5555FF'>=</font> filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
filters_nc <font color='#5555FF'>=</font> filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreateFilterDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnFilterDescriptor_t<font color='#5555FF'>*</font><font face='Lucida Console'>)</font><font color='#5555FF'>&</font>filter_handle<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetFilter4dDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle, |
|
CUDNN_DATA_FLOAT, |
|
CUDNN_TENSOR_NCHW, |
|
filters.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
filters.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreateConvolutionDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnConvolutionDescriptor_t<font color='#5555FF'>*</font><font face='Lucida Console'>)</font><font color='#5555FF'>&</font>conv_handle<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>#if</font> CUDNN_MAJOR <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>6</font> |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetConvolution2dDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle, |
|
padding_y, <font color='#009900'>// vertical padding |
|
</font> padding_x, <font color='#009900'>// horizontal padding |
|
</font> stride_y, |
|
stride_x, |
|
<font color='#979000'>1</font>, <font color='#979000'>1</font>, <font color='#009900'>// must be 1,1 |
|
</font> CUDNN_CROSS_CORRELATION, |
|
CUDNN_DATA_FLOAT<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <font color='#009900'>// could also be CUDNN_CONVOLUTION |
|
</font><font color='#0000FF'>#else</font> |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetConvolution2dDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle, |
|
padding_y, <font color='#009900'>// vertical padding |
|
</font> padding_x, <font color='#009900'>// horizontal padding |
|
</font> stride_y, |
|
stride_x, |
|
<font color='#979000'>1</font>, <font color='#979000'>1</font>, <font color='#009900'>// must be 1,1 |
|
</font> CUDNN_CROSS_CORRELATION<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <font color='#009900'>// could also be CUDNN_CONVOLUTION |
|
</font><font color='#0000FF'>#endif</font> |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolution2dForwardOutputDim</font><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle, |
|
<font color='#5555FF'>&</font>out_num_samples, |
|
<font color='#5555FF'>&</font>out_k, |
|
<font color='#5555FF'>&</font>out_nr, |
|
<font color='#5555FF'>&</font>out_nc<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
tensor_descriptor dest_desc; |
|
dest_desc.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font>out_num_samples,out_k,out_nr,out_nc<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>select_best_algorithms</font><font face='Lucida Console'>(</font>data, dest_desc<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionForwardWorkspaceSize</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font>cudnnConvolutionFwdAlgo_t<font face='Lucida Console'>)</font>forward_algo, |
|
<font color='#5555FF'>&</font>forward_workspace_size_in_bytes<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionBackwardDataWorkspaceSize</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font>cudnnConvolutionBwdDataAlgo_t<font face='Lucida Console'>)</font>backward_data_algo, |
|
<font color='#5555FF'>&</font>backward_data_workspace_size_in_bytes<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionBackwardFilterWorkspaceSize</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle, |
|
<font face='Lucida Console'>(</font>cudnnConvolutionBwdFilterAlgo_t<font face='Lucida Console'>)</font>backward_filters_algo, |
|
<font color='#5555FF'>&</font>backward_filters_workspace_size_in_bytes<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>...<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>throw</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
tensor_conv:: |
|
~<b><a name='tensor_conv'></a>tensor_conv</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> tensor_conv::<b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> add_to_output, |
|
resizable_tensor<font color='#5555FF'>&</font> output, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> data, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> filters |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>stride_y <font color='#5555FF'>></font> <font color='#979000'>0</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> stride_x <font color='#5555FF'>></font> <font color='#979000'>0</font>, "<font color='#CC0000'>You must call setup() before calling this function</font>"<font face='Lucida Console'>)</font>; |
|
|
|
output.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font>out_num_samples, out_k, out_nr, out_nc<font face='Lucida Console'>)</font>; |
|
<font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>add_to_output, <font color='#0000FF'>static_cast</font><font color='#5555FF'><</font>tensor<font color='#5555FF'>&</font><font color='#5555FF'>></font><font face='Lucida Console'>(</font>output<font face='Lucida Console'>)</font>, data, filters<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> tensor_conv::<b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> add_to_output, |
|
tensor<font color='#5555FF'>&</font> output, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> data, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> filters |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>output,data<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>false</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>output,filters<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>false</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>filters.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>stride_y <font color='#5555FF'>></font> <font color='#979000'>0</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> stride_x <font color='#5555FF'>></font> <font color='#979000'>0</font>, "<font color='#CC0000'>You must call setup() before calling this function</font>"<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_x, |
|
"<font color='#CC0000'>Filter windows must be small enough to fit into the padded image.</font>" |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t filters.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t data.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> data.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t padding_x: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> padding_x |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_y, |
|
"<font color='#CC0000'>Filter windows must be small enough to fit into the padded image.</font>" |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t filters.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t data.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> data.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t padding_y: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> padding_y |
|
<font face='Lucida Console'>)</font>; |
|
|
|
|
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>output.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,out_num_samples <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> data.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>output.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>output.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font color='#5555FF'>+</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font><font color='#979000'>2</font><font color='#5555FF'>*</font>padding_y<font color='#5555FF'>-</font>filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_y<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>output.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font color='#5555FF'>+</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font><font color='#979000'>2</font><font color='#5555FF'>*</font>padding_x<font color='#5555FF'>-</font>filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_x<font face='Lucida Console'>)</font>; |
|
|
|
|
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> add_to_output ? <font color='#979000'>1</font> : <font color='#979000'>0</font>; |
|
|
|
<font color='#009900'>// Since cudnnConvolutionForward() is an asynchronous call, we need to hold a |
|
</font> <font color='#009900'>// reference to the workspace buffer so we can be sure it isn't reallocated |
|
</font> <font color='#009900'>// while the function is still executing on the device. But each time we come |
|
</font> <font color='#009900'>// here, we make sure to grab the latest workspace buffer so that, globally, we |
|
</font> <font color='#009900'>// minimize the number of such buffers. |
|
</font> forward_workspace <font color='#5555FF'>=</font> <font color='#BB00BB'>device_global_buffer</font><font face='Lucida Console'>(</font>forward_workspace_size_in_bytes<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnConvolutionForward</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>, |
|
data.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle, |
|
filters.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle, |
|
<font face='Lucida Console'>(</font>cudnnConvolutionFwdAlgo_t<font face='Lucida Console'>)</font>forward_algo, |
|
forward_workspace, |
|
forward_workspace_size_in_bytes, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>output<font face='Lucida Console'>)</font>, |
|
output.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> tensor_conv::<b><a name='get_gradient_for_data'></a>get_gradient_for_data</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> add_to_output, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> filters, |
|
tensor<font color='#5555FF'>&</font> data_gradient |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> add_to_output ? <font color='#979000'>1</font> : <font color='#979000'>0</font>; |
|
|
|
<font color='#009900'>// Since cudnnConvolutionBackwardData() is an asynchronous call, we need to hold a |
|
</font> <font color='#009900'>// reference to the workspace buffer so we can be sure it isn't reallocated |
|
</font> <font color='#009900'>// while the function is still executing on the device. But each time we come |
|
</font> <font color='#009900'>// here, we make sure to grab the latest workspace buffer so that, globally, we |
|
</font> <font color='#009900'>// minimize the number of such buffers. |
|
</font> backward_data_workspace <font color='#5555FF'>=</font> <font color='#BB00BB'>device_global_buffer</font><font face='Lucida Console'>(</font>backward_data_workspace_size_in_bytes<font face='Lucida Console'>)</font>; |
|
|
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnConvolutionBackwardData</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle, |
|
filters.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle, |
|
<font face='Lucida Console'>(</font>cudnnConvolutionBwdDataAlgo_t<font face='Lucida Console'>)</font>backward_data_algo, |
|
backward_data_workspace, |
|
backward_data_workspace_size_in_bytes, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data_gradient<font face='Lucida Console'>)</font>, |
|
data_gradient.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> tensor_conv:: |
|
<b><a name='get_gradient_for_filters'></a>get_gradient_for_filters</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> add_to_output, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> data, |
|
tensor<font color='#5555FF'>&</font> filters_gradient |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> add_to_output ? <font color='#979000'>1</font> : <font color='#979000'>0</font>; |
|
|
|
<font color='#009900'>// Since cudnnConvolutionBackwardFilter() is an asynchronous call, we need to hold a |
|
</font> <font color='#009900'>// reference to the workspace buffer so we can be sure it isn't reallocated |
|
</font> <font color='#009900'>// while the function is still executing on the device. But each time we come |
|
</font> <font color='#009900'>// here, we make sure to grab the latest workspace buffer so that, globally, we |
|
</font> <font color='#009900'>// minimize the number of such buffers. |
|
</font> backward_filters_workspace <font color='#5555FF'>=</font> <font color='#BB00BB'>device_global_buffer</font><font face='Lucida Console'>(</font>backward_filters_workspace_size_in_bytes<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnConvolutionBackwardFilter</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>, |
|
data.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle, |
|
<font face='Lucida Console'>(</font>cudnnConvolutionBwdFilterAlgo_t<font face='Lucida Console'>)</font>backward_filters_algo, |
|
backward_filters_workspace, |
|
backward_filters_workspace_size_in_bytes, |
|
<font color='#5555FF'>&</font>beta, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle, |
|
filters_gradient.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> <font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
pooling::<b><a name='pooling'></a>pooling</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> : handle<font face='Lucida Console'>(</font>nullptr<font face='Lucida Console'>)</font>,window_height<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>,window_width<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>,stride_y<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>,stride_x<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>,padding_y<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>, padding_x<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<b>}</b> |
|
|
|
pooling::~<b><a name='pooling'></a>pooling</b><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> pooling:: |
|
<b><a name='clear'></a>clear</b><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>handle<font face='Lucida Console'>)</font> |
|
<font color='#BB00BB'>cudnnDestroyPoolingDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnPoolingDescriptor_t<font face='Lucida Console'>)</font>handle<font face='Lucida Console'>)</font>; |
|
handle <font color='#5555FF'>=</font> nullptr; |
|
window_height <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
window_width <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
stride_y <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
stride_x <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
padding_y <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
padding_x <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> pooling:: |
|
<b><a name='setup_max_pooling'></a>setup_max_pooling</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>int</u></font> window_height_, |
|
<font color='#0000FF'><u>int</u></font> window_width_, |
|
<font color='#0000FF'><u>int</u></font> stride_y_, |
|
<font color='#0000FF'><u>int</u></font> stride_x_, |
|
<font color='#0000FF'><u>int</u></font> padding_y_, |
|
<font color='#0000FF'><u>int</u></font> padding_x_ |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>setup</font><font face='Lucida Console'>(</font>window_height_, window_width_, stride_y_, stride_x_, padding_y_, padding_x_, CUDNN_POOLING_MAX<font face='Lucida Console'>)</font>; |
|
do_max_pooling <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> pooling:: |
|
<b><a name='setup_avg_pooling'></a>setup_avg_pooling</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>int</u></font> window_height_, |
|
<font color='#0000FF'><u>int</u></font> window_width_, |
|
<font color='#0000FF'><u>int</u></font> stride_y_, |
|
<font color='#0000FF'><u>int</u></font> stride_x_, |
|
<font color='#0000FF'><u>int</u></font> padding_y_, |
|
<font color='#0000FF'><u>int</u></font> padding_x_ |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>setup</font><font face='Lucida Console'>(</font>window_height_, window_width_, stride_y_, stride_x_, padding_y_, padding_x_, CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING<font face='Lucida Console'>)</font>; |
|
do_max_pooling <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> pooling:: |
|
<b><a name='setup'></a>setup</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>int</u></font> window_height_, |
|
<font color='#0000FF'><u>int</u></font> window_width_, |
|
<font color='#0000FF'><u>int</u></font> stride_y_, |
|
<font color='#0000FF'><u>int</u></font> stride_x_, |
|
<font color='#0000FF'><u>int</u></font> padding_y_, |
|
<font color='#0000FF'><u>int</u></font> padding_x_, |
|
<font color='#0000FF'><u>int</u></font> pooling_mode |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font> <font face='Lucida Console'>(</font>window_height_ <font color='#5555FF'>></font> <font color='#979000'>0</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> window_width_ <font color='#5555FF'>></font> <font color='#979000'>0</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
stride_y_ <font color='#5555FF'>></font> <font color='#979000'>0</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> stride_x_ <font color='#5555FF'>></font> <font color='#979000'>0</font> , |
|
"<font color='#CC0000'>window_height_: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> window_height_ |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\t\n window_width_: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> window_width_ |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\t\n stride_y_: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> stride_y_ |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\t\n stride_x_: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> stride_x_ <font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> <font color='#979000'>0</font> <font color='#5555FF'><</font><font color='#5555FF'>=</font> padding_y_ <font color='#5555FF'>&</font><font color='#5555FF'>&</font> padding_y_ <font color='#5555FF'><</font> window_height_ <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#979000'>0</font> <font color='#5555FF'><</font><font color='#5555FF'>=</font> padding_x_ <font color='#5555FF'>&</font><font color='#5555FF'>&</font> padding_x_ <font color='#5555FF'><</font> window_width_, |
|
"<font color='#CC0000'>window_height_: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> window_height_ |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\t\n window_width_: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> window_width_ |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\t\n padding_y_: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> padding_y_ |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\t\n padding_x_: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> padding_x_ <font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>window_height <font color='#5555FF'>=</font><font color='#5555FF'>=</font> window_height_ <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
window_width <font color='#5555FF'>=</font><font color='#5555FF'>=</font> window_width_ <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
stride_y <font color='#5555FF'>=</font><font color='#5555FF'>=</font> stride_y_ <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
stride_x <font color='#5555FF'>=</font><font color='#5555FF'>=</font> stride_x_ <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
padding_y <font color='#5555FF'>=</font><font color='#5555FF'>=</font> padding_y_ <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
padding_x <font color='#5555FF'>=</font><font color='#5555FF'>=</font> padding_x_ |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>return</font>; |
|
<b>}</b> |
|
|
|
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>try</font> |
|
<b>{</b> |
|
window_height <font color='#5555FF'>=</font> window_height_; |
|
window_width <font color='#5555FF'>=</font> window_width_; |
|
stride_x <font color='#5555FF'>=</font> stride_x_; |
|
stride_y <font color='#5555FF'>=</font> stride_y_; |
|
padding_y <font color='#5555FF'>=</font> padding_y_; |
|
padding_x <font color='#5555FF'>=</font> padding_x_; |
|
cudnnPoolingDescriptor_t poolingDesc; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreatePoolingDescriptor</font><font face='Lucida Console'>(</font><font color='#5555FF'>&</font>poolingDesc<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
handle <font color='#5555FF'>=</font> poolingDesc; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetPooling2dDescriptor</font><font face='Lucida Console'>(</font>poolingDesc, |
|
<font face='Lucida Console'>(</font>cudnnPoolingMode_t<font face='Lucida Console'>)</font>pooling_mode, |
|
CUDNN_PROPAGATE_NAN, |
|
window_height, |
|
window_width, |
|
padding_y, |
|
padding_x, |
|
stride_y, |
|
stride_x<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>...<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>throw</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> pooling:: |
|
<b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font> |
|
resizable_tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>window_width <font color='#5555FF'><</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_x, |
|
"<font color='#CC0000'>Pooling windows must be small enough to fit into the padded image.</font>" |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t window_width: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> window_width |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t src.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t padding_x: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> padding_x |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>window_height <font color='#5555FF'><</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_y, |
|
"<font color='#CC0000'>Pooling windows must be small enough to fit into the padded image.</font>" |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t window_height: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> window_height |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t src.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\n\t padding_y: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> padding_y |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'><u>int</u></font> outN; |
|
<font color='#0000FF'><u>int</u></font> outC; |
|
<font color='#0000FF'><u>int</u></font> outH; |
|
<font color='#0000FF'><u>int</u></font> outW; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetPooling2dForwardOutputDim</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnPoolingDescriptor_t<font face='Lucida Console'>)</font>handle, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>outN, |
|
<font color='#5555FF'>&</font>outC, |
|
<font color='#5555FF'>&</font>outH, |
|
<font color='#5555FF'>&</font>outW<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
|
|
dest.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font>outN,outC,outH,outW<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>+</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_y <font color='#5555FF'>-</font> window_height<font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_y, |
|
"<font color='#CC0000'>\n stride_y: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> stride_y <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\n padding_y: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> padding_y <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\n window_height: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> window_height <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\n src.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\n dest.nr(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> dest.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\n src.nr()/stride_y: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_y<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>+</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_x <font color='#5555FF'>-</font> window_width<font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_x, |
|
"<font color='#CC0000'>\n stride_x: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> stride_x <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\n padding_x: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> padding_x <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\n window_width: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> window_width <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\n src.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\n dest.nc(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> dest.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\n src.nc()/stride_x: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_x<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnPoolingForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnPoolingDescriptor_t<font face='Lucida Console'>)</font>handle, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> pooling::<b><a name='get_gradient'></a>get_gradient</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src, |
|
tensor<font color='#5555FF'>&</font> grad |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gradient_input,dest<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>src,grad<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnPoolingBackward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnPoolingDescriptor_t<font face='Lucida Console'>)</font>handle, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>, |
|
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> <font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='softmax'></a>softmax</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSoftmaxForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
CUDNN_SOFTMAX_ACCURATE, |
|
CUDNN_SOFTMAX_MODE_CHANNEL, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='softmax_gradient'></a>softmax_gradient</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> grad, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,grad<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font> ? <font color='#979000'>0</font> : <font color='#979000'>1</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSoftmaxBackward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
CUDNN_SOFTMAX_ACCURATE, |
|
CUDNN_SOFTMAX_MODE_CHANNEL, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>, |
|
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> <font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='softmax_all'></a>softmax_all</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSoftmaxForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
CUDNN_SOFTMAX_ACCURATE, |
|
CUDNN_SOFTMAX_MODE_INSTANCE, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='softmax_all_gradient'></a>softmax_all_gradient</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> grad, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,grad<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font> ? <font color='#979000'>0</font> : <font color='#979000'>1</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSoftmaxBackward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
CUDNN_SOFTMAX_ACCURATE, |
|
CUDNN_SOFTMAX_MODE_INSTANCE, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>, |
|
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> <font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='sigmoid'></a>sigmoid</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>sigmoid_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Backward pass of the sigmoid activation, computed on the GPU by</font>
<font color='#009900'>// cudnnActivationBackward() with sigmoid_activation_descriptor().</font>
<font color='#009900'>//   grad           - destination of the result; same dimensions as dest.</font>
<font color='#009900'>//   dest           - passed as both the y (output) and x (input) tensors of the</font>
<font color='#009900'>//                    cuDNN call; presumably the output of the forward sigmoid()</font>
<font color='#009900'>//                    -- TODO confirm against callers.</font>
<font color='#009900'>//   gradient_input - incoming gradient (dy); same dimensions as dest.</font>
<font color='#009900'>// If grad and gradient_input are the same object the result overwrites grad;</font>
<font color='#009900'>// otherwise it is added to grad's existing contents. Returns immediately</font>
<font color='#009900'>// (after the dimension check) when dest is empty.</font>
<font color='#0000FF'><u>void</u></font> <b><a name='sigmoid_gradient'></a>sigmoid_gradient</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> grad, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,grad<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#009900'>// cuDNN blends the output as alpha*result + beta*prior. When grad aliases</font>
<font color='#009900'>// gradient_input its prior contents are the incoming gradient itself, so</font>
<font color='#009900'>// beta = 0 overwrites in place; otherwise beta = 1 accumulates into grad.</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font> ? <font color='#979000'>0</font> : <font color='#979000'>1</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationBackward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>sigmoid_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>, |
|
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#009900'>// Forward ReLU activation computed on the GPU by cudnnActivationForward()</font>
<font color='#009900'>// with relu_activation_descriptor(). dest must have the same dimensions as</font>
<font color='#009900'>// src. With alpha = 1 and beta = 0 the result overwrites dest's prior</font>
<font color='#009900'>// contents rather than being blended with them. Returns immediately (after</font>
<font color='#009900'>// the dimension check) when src is empty.</font>
<font color='#0000FF'><u>void</u></font> <b><a name='relu'></a>relu</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>relu_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Backward pass of the ReLU activation, computed on the GPU by</font>
<font color='#009900'>// cudnnActivationBackward() with relu_activation_descriptor().</font>
<font color='#009900'>//   grad           - destination of the result; same dimensions as dest.</font>
<font color='#009900'>//   dest           - passed as both the y (output) and x (input) tensors of the</font>
<font color='#009900'>//                    cuDNN call; presumably the output of the forward relu()</font>
<font color='#009900'>//                    -- TODO confirm against callers.</font>
<font color='#009900'>//   gradient_input - incoming gradient (dy); same dimensions as dest.</font>
<font color='#009900'>// If grad and gradient_input are the same object the result overwrites grad;</font>
<font color='#009900'>// otherwise it is added to grad's existing contents. Returns immediately</font>
<font color='#009900'>// (after the dimension check) when dest is empty.</font>
<font color='#0000FF'><u>void</u></font> <b><a name='relu_gradient'></a>relu_gradient</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> grad, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,grad<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#009900'>// cuDNN blends the output as alpha*result + beta*prior. When grad aliases</font>
<font color='#009900'>// gradient_input its prior contents are the incoming gradient itself, so</font>
<font color='#009900'>// beta = 0 overwrites in place; otherwise beta = 1 accumulates into grad.</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font> ? <font color='#979000'>0</font> : <font color='#979000'>1</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationBackward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>relu_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>, |
|
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> |
|
<font color='#009900'>// Forward tanh activation computed on the GPU by cudnnActivationForward()</font>
<font color='#009900'>// with tanh_activation_descriptor(). dest must have the same dimensions as</font>
<font color='#009900'>// src. With alpha = 1 and beta = 0 the result overwrites dest's prior</font>
<font color='#009900'>// contents rather than being blended with them. Returns immediately (after</font>
<font color='#009900'>// the dimension check) when src is empty.</font>
<font color='#0000FF'><u>void</u></font> <b><a name='tanh'></a>tanh</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> src |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>tanh_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, |
|
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Backward pass of the tanh activation, computed on the GPU by</font>
<font color='#009900'>// cudnnActivationBackward() with tanh_activation_descriptor().</font>
<font color='#009900'>//   grad           - destination of the result; same dimensions as dest.</font>
<font color='#009900'>//   dest           - passed as both the y (output) and x (input) tensors of the</font>
<font color='#009900'>//                    cuDNN call; presumably the output of the forward tanh()</font>
<font color='#009900'>//                    -- TODO confirm against callers.</font>
<font color='#009900'>//   gradient_input - incoming gradient (dy); same dimensions as dest.</font>
<font color='#009900'>// If grad and gradient_input are the same object the result overwrites grad;</font>
<font color='#009900'>// otherwise it is added to grad's existing contents. Returns immediately</font>
<font color='#009900'>// (after the dimension check) when dest is empty.</font>
<font color='#0000FF'><u>void</u></font> <b><a name='tanh_gradient'></a>tanh_gradient</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> grad, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> dest, |
|
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> gradient_input |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> |
|
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,grad<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>; |
|
<font color='#009900'>// cuDNN blends the output as alpha*result + beta*prior. When grad aliases</font>
<font color='#009900'>// gradient_input its prior contents are the incoming gradient itself, so</font>
<font color='#009900'>// beta = 0 overwrites in place; otherwise beta = 1 accumulates into grad.</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font> ? <font color='#979000'>0</font> : <font color='#979000'>1</font>; |
|
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationBackward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>tanh_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>alpha, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, |
|
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, |
|
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
<font color='#5555FF'>&</font>beta, |
|
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>, |
|
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ------------------------------------------------------------------------------------ |
|
</font> <b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_USE_CUDA |
|
</font> |
|
<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_DNN_CuDNN_CPP_ |
|
</font> |
|
|
|
|
|
</pre></body></html> |