Aging_MouthReplace / dlibs /docs /dlib /cuda /cudnn_dlibapi.cpp.html
AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
raw
history blame
230 kB
<html><!-- Created using the cpp_pretty_printer from the dlib C++ library. See http://dlib.net for updates. --><head><title>dlib C++ Library - cudnn_dlibapi.cpp</title></head><body bgcolor='white'><pre>
<font color='#009900'>// Copyright (C) 2015 Davis E. King ([email protected])
</font><font color='#009900'>// License: Boost Software License See LICENSE.txt for the full license.
</font><font color='#0000FF'>#ifndef</font> DLIB_DNN_CuDNN_CPP_
<font color='#0000FF'>#define</font> DLIB_DNN_CuDNN_CPP_
<font color='#0000FF'>#ifdef</font> DLIB_USE_CUDA
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='cudnn_dlibapi.h.html'>cudnn_dlibapi.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='tensor.h.html'>tensor.h</a>"
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>cudnn.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>tuple<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>map<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iostream<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>string<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>vector<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='cuda_utils.h.html'>cuda_utils.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='cpu_dlib.h.html'>cpu_dlib.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='cuda_dlib.h.html'>cuda_dlib.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='tensor_tools.h.html'>tensor_tools.h</a>"
<font color='#009900'>// Translate a cuDNN status code into a human readable message.  A few statuses get a
</font><font color='#009900'>// hand written explanation; any other status is described by cuDNN's own
</font><font color='#009900'>// cudnnGetErrorString() so that unlisted failures are still reported by name.
</font><font color='#0000FF'>static</font> <font color='#0000FF'>const</font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font> <b><a name='cudnn_get_error_string'></a>cudnn_get_error_string</b><font face='Lucida Console'>(</font>cudnnStatus_t s<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#0000FF'>switch</font><font face='Lucida Console'>(</font>s<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>case</font> CUDNN_STATUS_NOT_INITIALIZED:
            <font color='#0000FF'>return</font> "<font color='#CC0000'>CUDA Runtime API initialization failed.</font>";
        <font color='#0000FF'>case</font> CUDNN_STATUS_ALLOC_FAILED:
            <font color='#0000FF'>return</font> "<font color='#CC0000'>CUDA Resources could not be allocated.</font>";
        <font color='#0000FF'>case</font> CUDNN_STATUS_BAD_PARAM:
            <font color='#0000FF'>return</font> "<font color='#CC0000'>CUDNN_STATUS_BAD_PARAM</font>";
        <font color='#0000FF'>case</font> CUDNN_STATUS_EXECUTION_FAILED:
            <font color='#0000FF'>return</font> "<font color='#CC0000'>CUDNN_STATUS_EXECUTION_FAILED</font>";
        <font color='#0000FF'>case</font> CUDNN_STATUS_NOT_SUPPORTED:
            <font color='#0000FF'>return</font> "<font color='#CC0000'>CUDNN_STATUS_NOT_SUPPORTED</font>";
        <font color='#0000FF'>case</font> CUDNN_STATUS_ARCH_MISMATCH:
            <font color='#0000FF'>return</font> "<font color='#CC0000'>CUDNN_STATUS_ARCH_MISMATCH: Your GPU is too old and not supported by cuDNN</font>";
        <font color='#0000FF'>default</font>:
            <font color='#009900'>// Let cuDNN name every status that has no custom text above.
</font>            <font color='#0000FF'>return</font> <font color='#BB00BB'>cudnnGetErrorString</font><font face='Lucida Console'>(</font>s<font face='Lucida Console'>)</font>;
    <b>}</b>
<b>}</b>
<font color='#009900'>// Check the return value of a call to the cuDNN runtime for an error condition.  On
</font><font color='#009900'>// failure a dlib::cudnn_error is thrown whose message holds the text of the call, the
</font><font color='#009900'>// file and line, the numeric status and its description.  The argument is evaluated
</font><font color='#009900'>// exactly once.  It is parenthesized, and the macro's locals use distinctive names, so
</font><font color='#009900'>// the expression passed in can neither be reparsed nor capture the macro's variables.
</font><font color='#0000FF'>#define</font> CHECK_CUDNN<font face='Lucida Console'>(</font>call<font face='Lucida Console'>)</font> \
<font color='#0000FF'>do</font><b>{</b> \
    <font color='#0000FF'>const</font> cudnnStatus_t cudnn_status_ <font color='#5555FF'>=</font> <font face='Lucida Console'>(</font>call<font face='Lucida Console'>)</font>; \
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>cudnn_status_ <font color='#5555FF'>!</font><font color='#5555FF'>=</font> CUDNN_STATUS_SUCCESS<font face='Lucida Console'>)</font> \
    <b>{</b> \
        std::ostringstream cudnn_sout_; \
        cudnn_sout_ <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Error while calling </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> #call <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> in file </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> __FILE__ <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>:</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> __LINE__ <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>. </font>";\
        cudnn_sout_ <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>code: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> cudnn_status_ <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>, reason: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>cudnn_get_error_string</font><font face='Lucida Console'>(</font>cudnn_status_<font face='Lucida Console'>)</font>;\
        <font color='#0000FF'>throw</font> dlib::<font color='#BB00BB'>cudnn_error</font><font face='Lucida Console'>(</font>cudnn_sout_.<font color='#BB00BB'>str</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; \
    <b>}</b> \
<b>}</b><font color='#0000FF'>while</font><font face='Lucida Console'>(</font><font color='#979000'>false</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>namespace</font> dlib
<b>{</b>
<font color='#0000FF'>namespace</font> cuda
<b>{</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Return the handle stored in tensor t's cuDNN tensor descriptor, cast to the cuDNN
</font><font color='#009900'>// descriptor type.
</font><font color='#0000FF'>static</font> cudnnTensorDescriptor_t <b><a name='descriptor'></a>descriptor</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> t<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> raw_handle <font color='#5555FF'>=</font> t.<font color='#BB00BB'>get_cudnn_tensor_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> <font face='Lucida Console'>(</font>cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>raw_handle;
<b>}</b>
<font color='#009900'>// Return the handle stored in tensor_descriptor t, cast to the cuDNN descriptor type.
</font><font color='#0000FF'>static</font> cudnnTensorDescriptor_t <b><a name='descriptor'></a>descriptor</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> tensor_descriptor<font color='#5555FF'>&amp;</font> t<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> raw_handle <font color='#5555FF'>=</font> t.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> <font face='Lucida Console'>(</font>cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>raw_handle;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Holds a table of cuDNN handles indexed by CUDA device id.  Handles are created on
</font><font color='#009900'>// demand by get_handle() and destroyed together with the context.
</font><font color='#0000FF'>class</font> <b><a name='cudnn_context'></a>cudnn_context</b>
<b>{</b>
<font color='#0000FF'>public</font>:
    <font color='#009900'>// Copying is disabled.
</font>    <b><a name='cudnn_context'></a>cudnn_context</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnn_context<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>;
    cudnn_context<font color='#5555FF'>&amp;</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>=</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnn_context<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>;

    <font color='#009900'>// Start with 16 empty slots; get_handle() grows the table when needed.
</font>    <b><a name='cudnn_context'></a>cudnn_context</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> : <font color='#BB00BB'>handles</font><font face='Lucida Console'>(</font><font color='#979000'>16</font><font face='Lucida Console'>)</font>
    <b>{</b>
    <b>}</b>

    <font color='#009900'>// Destroy every handle that was actually created.
</font>    ~<b><a name='cudnn_context'></a>cudnn_context</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font>cudnnHandle_t created : handles<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>created<font face='Lucida Console'>)</font>
                <font color='#0000FF'>continue</font>;
            <font color='#BB00BB'>cudnnDestroy</font><font face='Lucida Console'>(</font>created<font face='Lucida Console'>)</font>;
        <b>}</b>
    <b>}</b>

    <font color='#009900'>// Return the handle belonging to the device reported by cudaGetDevice(), creating
</font>    <font color='#009900'>// it first if this context has not made one for that device yet.
</font>    cudnnHandle_t <b><a name='get_handle'></a>get_handle</b> <font face='Lucida Console'>(</font>
    <font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'><u>int</u></font> new_device_id;
        <font color='#BB00BB'>CHECK_CUDA</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudaGetDevice</font><font face='Lucida Console'>(</font><font color='#5555FF'>&amp;</font>new_device_id<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

        <font color='#009900'>// Grow the table when the device id falls outside of it.
</font>        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>handles.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> new_device_id<font face='Lucida Console'>)</font>
        <b>{</b>
            handles.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>new_device_id<font color='#5555FF'>+</font><font color='#979000'>16</font><font face='Lucida Console'>)</font>;
        <b>}</b>

        <font color='#009900'>// An empty slot means no handle exists for this device yet.
</font>        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>handles[new_device_id] <font color='#5555FF'>=</font><font color='#5555FF'>=</font> nullptr<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreate</font><font face='Lucida Console'>(</font><font color='#5555FF'>&amp;</font>handles[new_device_id]<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        <b>}</b>

        <font color='#0000FF'>return</font> handles[new_device_id];
    <b>}</b>

<font color='#0000FF'>private</font>:
    <font color='#009900'>// handles[i] is the handle for device i, or null if none has been created.
</font>    std::vector<font color='#5555FF'>&lt;</font>cudnnHandle_t<font color='#5555FF'>&gt;</font> handles;
<b>}</b>;
<font color='#009900'>// Return a cuDNN handle from the calling thread's own cudnn_context.  The context is a
</font><font color='#009900'>// thread_local object, so each thread gets a separate one.
</font><font color='#0000FF'>static</font> cudnnHandle_t <b><a name='context'></a>context</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
    thread_local cudnn_context per_thread_context;
    <font color='#0000FF'>return</font> per_thread_context.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Owning wrapper around a cudnnActivationDescriptor_t.  The descriptor is created and
</font><font color='#009900'>// configured in the constructor and destroyed in the destructor.
</font><font color='#0000FF'>class</font> <b><a name='cudnn_activation_descriptor'></a>cudnn_activation_descriptor</b>
<b>{</b>
<font color='#0000FF'>public</font>:
    <font color='#009900'>// not copyable
</font>    <b><a name='cudnn_activation_descriptor'></a>cudnn_activation_descriptor</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnn_activation_descriptor<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>;
    cudnn_activation_descriptor<font color='#5555FF'>&amp;</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>=</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnn_activation_descriptor<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>;

    <font color='#009900'>// Create the descriptor and configure it with the given mode, NaN propagation
</font>    <font color='#009900'>// option and ceiling.  Throws dlib::cudnn_error if either cuDNN call fails.
</font>    <b><a name='cudnn_activation_descriptor'></a>cudnn_activation_descriptor</b><font face='Lucida Console'>(</font>
        cudnnActivationMode_t mode,
        cudnnNanPropagation_t reluNanOpt,
        <font color='#0000FF'><u>double</u></font> reluCeiling
    <font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreateActivationDescriptor</font><font face='Lucida Console'>(</font><font color='#5555FF'>&amp;</font>handle<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'>try</font>
        <b>{</b>
            <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetActivationDescriptor</font><font face='Lucida Console'>(</font>handle, mode, reluNanOpt, reluCeiling<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        <b>}</b>
        <font color='#0000FF'>catch</font> <font face='Lucida Console'>(</font>...<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#009900'>// The destructor does not run when the constructor throws, so the
</font>            <font color='#009900'>// descriptor created above has to be released here to avoid leaking it.
</font>            <font color='#BB00BB'>cudnnDestroyActivationDescriptor</font><font face='Lucida Console'>(</font>handle<font face='Lucida Console'>)</font>;
            <font color='#0000FF'>throw</font>;
        <b>}</b>
    <b>}</b>

    ~<b><a name='cudnn_activation_descriptor'></a>cudnn_activation_descriptor</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#BB00BB'>cudnnDestroyActivationDescriptor</font><font face='Lucida Console'>(</font>handle<font face='Lucida Console'>)</font>;
    <b>}</b>

    <font color='#009900'>// Return the wrapped descriptor.  Ownership stays with this object.
</font>    cudnnActivationDescriptor_t <b><a name='get_handle'></a>get_handle</b> <font face='Lucida Console'>(</font>
    <font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>return</font> handle;
    <b>}</b>

<font color='#0000FF'>private</font>:
    cudnnActivationDescriptor_t handle;
<b>}</b>;
<font color='#009900'>// Return the calling thread's activation descriptor for CUDNN_ACTIVATION_RELU.  The
</font><font color='#009900'>// descriptor is a thread_local object, constructed on the first call in each thread.
</font><font color='#0000FF'>static</font> cudnnActivationDescriptor_t <b><a name='relu_activation_descriptor'></a>relu_activation_descriptor</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
    thread_local cudnn_activation_descriptor <font color='#BB00BB'>relu_desc</font><font face='Lucida Console'>(</font>
        CUDNN_ACTIVATION_RELU,
        CUDNN_PROPAGATE_NAN,
        <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> relu_desc.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Return the calling thread's activation descriptor for CUDNN_ACTIVATION_SIGMOID.  The
</font><font color='#009900'>// descriptor is a thread_local object, constructed on the first call in each thread.
</font><font color='#0000FF'>static</font> cudnnActivationDescriptor_t <b><a name='sigmoid_activation_descriptor'></a>sigmoid_activation_descriptor</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
    thread_local cudnn_activation_descriptor <font color='#BB00BB'>sigmoid_desc</font><font face='Lucida Console'>(</font>
        CUDNN_ACTIVATION_SIGMOID,
        CUDNN_PROPAGATE_NAN,
        <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> sigmoid_desc.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Return the calling thread's activation descriptor for CUDNN_ACTIVATION_TANH.  The
</font><font color='#009900'>// descriptor is a thread_local object, constructed on the first call in each thread.
</font><font color='#0000FF'>static</font> cudnnActivationDescriptor_t <b><a name='tanh_activation_descriptor'></a>tanh_activation_descriptor</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
    thread_local cudnn_activation_descriptor <font color='#BB00BB'>tanh_desc</font><font face='Lucida Console'>(</font>
        CUDNN_ACTIVATION_TANH,
        CUDNN_PROPAGATE_NAN,
        <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> tanh_desc.<font color='#BB00BB'>get_handle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Construct an empty descriptor: no cuDNN handle is held until set_size() makes one.
</font>tensor_descriptor::<b><a name='tensor_descriptor'></a>tensor_descriptor</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
    : handle<font face='Lucida Console'>(</font>nullptr<font face='Lucida Console'>)</font>
<b>{</b><b>}</b>
<font color='#009900'>// Release the cuDNN handle, if one is held.
</font>tensor_descriptor::~<b><a name='tensor_descriptor'></a>tensor_descriptor</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// With all dimensions zero, set_size() destroys the current handle and does not
</font>    <font color='#009900'>// create a replacement.
</font>    <font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Replace the held cuDNN descriptor with a 4D NCHW float descriptor of the given
</font><font color='#009900'>// dimensions.  If any dimension is zero the descriptor is left empty (null handle).
</font><font color='#009900'>// Throws dlib::cudnn_error when cuDNN fails; in that case the descriptor is left empty
</font><font color='#009900'>// rather than holding a created but unconfigured cuDNN descriptor.
</font><font color='#0000FF'><u>void</u></font> tensor_descriptor::
<b><a name='set_size'></a>set_size</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'><u>int</u></font> n,
    <font color='#0000FF'><u>int</u></font> k,
    <font color='#0000FF'><u>int</u></font> nr,
    <font color='#0000FF'><u>int</u></font> nc
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Drop whatever descriptor is currently held.
</font>    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>handle<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#BB00BB'>cudnnDestroyTensorDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>handle<font face='Lucida Console'>)</font>;
        handle <font color='#5555FF'>=</font> nullptr;
    <b>}</b>

    <font color='#009900'>// Only a tensor with all dimensions non-zero gets a cuDNN descriptor.
</font>    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>n <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> nr <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> nc <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> k <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
    <b>{</b>
        cudnnTensorDescriptor_t h;
        <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreateTensorDescriptor</font><font face='Lucida Console'>(</font><font color='#5555FF'>&amp;</font>h<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        handle <font color='#5555FF'>=</font> h;
        <font color='#0000FF'>try</font>
        <b>{</b>
            <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetTensor4dDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>handle,
                        CUDNN_TENSOR_NCHW,
                        CUDNN_DATA_FLOAT,
                        n,
                        k,
                        nr,
                        nc<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        <b>}</b>
        <font color='#0000FF'>catch</font> <font face='Lucida Console'>(</font>...<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#009900'>// Configuration failed: free the new descriptor and go back to the empty
</font>            <font color='#009900'>// state so later calls never see an unconfigured descriptor.
</font>            <font color='#BB00BB'>cudnnDestroyTensorDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>handle<font face='Lucida Console'>)</font>;
            handle <font color='#5555FF'>=</font> nullptr;
            <font color='#0000FF'>throw</font>;
        <b>}</b>
    <b>}</b>
<b>}</b>
<font color='#009900'>// Report the dimensions stored in the descriptor through the four output references.
</font><font color='#009900'>// An empty descriptor (null handle) reports zero for every dimension.
</font><font color='#0000FF'><u>void</u></font> tensor_descriptor::
<b><a name='get_size'></a>get_size</b> <font face='Lucida Console'>(</font>
    <font color='#0000FF'><u>int</u></font><font color='#5555FF'>&amp;</font> n,
    <font color='#0000FF'><u>int</u></font><font color='#5555FF'>&amp;</font> k,
    <font color='#0000FF'><u>int</u></font><font color='#5555FF'>&amp;</font> nr,
    <font color='#0000FF'><u>int</u></font><font color='#5555FF'>&amp;</font> nc
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
    <font color='#009900'>// Nothing to query when no descriptor is held.
</font>    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>handle<font face='Lucida Console'>)</font>
    <b>{</b>
        n <font color='#5555FF'>=</font> k <font color='#5555FF'>=</font> nr <font color='#5555FF'>=</font> nc <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
        <font color='#0000FF'>return</font>;
    <b>}</b>

    <font color='#009900'>// The query also writes the data type and strides; they are received into locals
</font>    <font color='#009900'>// and not used further.
</font>    cudnnDataType_t datatype;
    <font color='#0000FF'><u>int</u></font> nStride, cStride, hStride, wStride;
    <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetTensor4dDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnTensorDescriptor_t<font face='Lucida Console'>)</font>handle,
                <font color='#5555FF'>&amp;</font>datatype,
                <font color='#5555FF'>&amp;</font>n,
                <font color='#5555FF'>&amp;</font>k,
                <font color='#5555FF'>&amp;</font>nr,
                <font color='#5555FF'>&amp;</font>nc,
                <font color='#5555FF'>&amp;</font>nStride,
                <font color='#5555FF'>&amp;</font>cStride,
                <font color='#5555FF'>&amp;</font>hStride,
                <font color='#5555FF'>&amp;</font>wStride<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Accumulate src into dest with scaling, i.e. dest = beta*dest + alpha*src as done by
</font><font color='#009900'>// cudnnAddTensor (the two dlib fast paths below are presumed to compute the same thing).
</font><font color='#009900'>// src must have dest's dimensions or one of the broadcastable shapes checked by the
</font><font color='#009900'>// assert, and src and dest must be different objects.
</font><font color='#0000FF'><u>void</u></font> <b><a name='add'></a>add</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'><u>float</u></font> beta,
    tensor<font color='#5555FF'>&amp;</font> dest,
    <font color='#0000FF'><u>float</u></font> alpha,
    <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
        <font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>src, dest<font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font>
        <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font><font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font>
        <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font>
        <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font>
        <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
        <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>src,dest<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>false</font> ,
        "<font color='#CC0000'>\n\t dest.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> dest.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
        <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>"<font color='#CC0000'>\n\t dest.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> dest.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
        <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>"<font color='#CC0000'>\n\t dest.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> dest.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
        <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>"<font color='#CC0000'>\n\t dest.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> dest.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
        <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>"<font color='#CC0000'>\n\t src.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
        <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>"<font color='#CC0000'>\n\t src.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
        <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>"<font color='#CC0000'>\n\t src.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
        <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>"<font color='#CC0000'>\n\t src.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
    <font face='Lucida Console'>)</font>;

    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> beta <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// Equal sizes and beta of one: dlib's own add_scaled() is used because it is
</font>        <font color='#009900'>// faster than the cuDNN routine (at least as of cuDNN v4).
</font>        <font color='#BB00BB'>add_scaled</font><font face='Lucida Console'>(</font>dest, alpha, src<font face='Lucida Console'>)</font>;
    <b>}</b>
    <font color='#0000FF'>else</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font>dest.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// src holds one value per sample: handled by add_cv_to_all_columns().
</font>        <font color='#BB00BB'>add_cv_to_all_columns</font><font face='Lucida Console'>(</font>beta, dest, alpha, src<font face='Lucida Console'>)</font>;
    <b>}</b>
    <font color='#0000FF'>else</font>
    <b>{</b>
        <font color='#009900'>// Every remaining case goes to cuDNN.
</font>        <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnAddTensor</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
                    <font color='#5555FF'>&amp;</font>alpha,
                    <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
                    src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
                    <font color='#5555FF'>&amp;</font>beta,
                    <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
                    dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <b>}</b>
<b>}</b>
<!--
  assign_conv_bias_gradient(grad, gradient_input)
  Rendered C++ wrapper around cudnnConvolutionBackwardBias.
  Inputs:  gradient_input is passed as the source tensor (descriptor + device pointer).
  Output:  grad is passed as the destination tensor (descriptor + device pointer).
  Preconditions enforced by DLIB_CASSERT:
    grad.num_samples() == 1, grad.k() >= 1, grad.nr() == 1, grad.nc() == 1,
    gradient_input.k() == grad.k(), gradient_input.size() > 0,
    and grad / gradient_input must not be the same object.
  Scaling factors handed to cuDNN: alpha = 1, beta = 0.
  NOTE(review): beta = 0 presumably means grad is overwritten rather than accumulated
  into; confirm against the cuDNN documentation.
  The returned cuDNN status goes through CHECK_CUDNN, which is defined outside this excerpt.
-->
<font color='#0000FF'><u>void</u></font> <b><a name='assign_conv_bias_gradient'></a>assign_conv_bias_gradient</b> <font face='Lucida Console'>(</font>
tensor<font color='#5555FF'>&amp;</font> grad,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
grad.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
grad.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
grad.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
grad.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gradient_input.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> grad.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gradient_input.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>false</font>
<font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnConvolutionBackwardBias</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>,
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>,
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<!--
  batch_normalize_inference(eps, dest, src, gamma, beta, running_means, running_variances)
  Rendered C++ wrapper around cudnnBatchNormalizationForwardInference, invoked in
  CUDNN_BATCHNORM_PER_ACTIVATION mode.
  Preconditions enforced by DLIB_CASSERT:
    gamma.num_samples() == 1; gamma's nr, nc and k equal those of src;
    beta, running_means and running_variances each have the same dimensions as gamma;
    eps > 0.
  On failure the assert message prints the dimensions of gamma, beta, running_means and
  running_variances, the k/nr/nc of src, and eps.
  Steps:
    1. dest.copy_size(src) is called before the cuDNN call (presumably resizing dest to
       match src; confirm in tensor.h).
    2. cuDNN is called with in_scale = 1 and out_scale = 0, src as input, dest as output,
       descriptor(gamma) as the single descriptor for the gamma, beta, running_means and
       running_variances device pointers, and eps as the last argument.
  The returned cuDNN status goes through CHECK_CUDNN, which is defined outside this excerpt.
-->
<font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize_inference'></a>batch_normalize_inference</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps,
resizable_tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gamma,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> beta,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> running_means,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> running_variances
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, beta<font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, running_means<font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, running_variances<font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
eps <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>,
"<font color='#CC0000'>\ngamma.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_means.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_means.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_means.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_means.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_means.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_means.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_means.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_means.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_variances.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_variances.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_variances.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_variances.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_variances.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_variances.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_variances.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_variances.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\neps: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> eps
<font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
dest.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationForwardInference</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
CUDNN_BATCHNORM_PER_ACTIVATION,
<font color='#5555FF'>&amp;</font>in_scale,
<font color='#5555FF'>&amp;</font>out_scale,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>,
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
beta.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
running_means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
running_variances.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
eps<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!--
  batch_normalize(eps, dest, means, invstds, averaging_factor, running_means,
                  running_variances, src, gamma, beta)
  Rendered C++ wrapper around cudnnBatchNormalizationForwardTraining, invoked in
  CUDNN_BATCHNORM_PER_ACTIVATION mode.
  Preconditions enforced by DLIB_CASSERT:
    0 <= averaging_factor <= 1;
    unless averaging_factor == 1, running_means must already have the dimensions of means
    and running_variances those of invstds (both checks run before any tensor is resized,
    so they apply to the sizes supplied by the caller);
    src.num_samples() > 1; gamma.num_samples() == 1; beta.num_samples() == 1;
    gamma, beta and src agree on nr, nc and k; eps > 0.
  NOTE(review): the failure message of the large assert does not print src.num_samples()
  even though that value is part of the checked condition.
  Steps after validation:
    1. dest.copy_size(src); means.set_size(1, src.k(), src.nr(), src.nc()); invstds,
       running_means and running_variances each call copy_size(means).
    2. When averaging_factor == 1, running_means is assigned 0 and running_variances is
       assigned 1; the rendered comment in the body gives the reason (cuDNN needs valid
       float values there even when the averaging factor would ignore them).
    3. cuDNN is called with in_scale = 1, out_scale = 0, src as input, dest as output,
       descriptor(gamma) for the gamma/beta pointers, then averaging_factor, the running
       statistics pointers, eps, and finally the means and invstds pointers.
  The returned cuDNN status goes through CHECK_CUDNN, which is defined outside this excerpt.
-->
<font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize'></a>batch_normalize</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps,
resizable_tensor<font color='#5555FF'>&amp;</font> dest,
resizable_tensor<font color='#5555FF'>&amp;</font> means,
resizable_tensor<font color='#5555FF'>&amp;</font> invstds,
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> averaging_factor,
resizable_tensor<font color='#5555FF'>&amp;</font> running_means,
resizable_tensor<font color='#5555FF'>&amp;</font> running_variances,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gamma,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> beta
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#979000'>0</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> averaging_factor <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> averaging_factor <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>1</font>, "<font color='#CC0000'>averaging_factor: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> averaging_factor<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>averaging_factor<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>running_means,means<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>averaging_factor<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>running_variances,invstds<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
eps <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>,
"<font color='#CC0000'>\ngamma.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\neps: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> eps
<font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
dest.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>;
means.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>1</font>, src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
invstds.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>;
running_means.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>;
running_variances.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>;
<font color='#009900'>// cuDNN requires that running_means and running_variances be initialized to
</font> <font color='#009900'>// some valid float values even if the averaging factor would have ignored
</font> <font color='#009900'>// them.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>averaging_factor <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>
<b>{</b>
running_means <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
running_variances <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<b>}</b>
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationForwardTraining</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
CUDNN_BATCHNORM_PER_ACTIVATION,
<font color='#5555FF'>&amp;</font>in_scale,
<font color='#5555FF'>&amp;</font>out_scale,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>,
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
beta.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
averaging_factor,
running_means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
running_variances.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
eps,
means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
invstds.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!--
  batch_normalize_gradient(eps, gradient_input, means, invstds, src, gamma,
                           src_grad, gamma_grad, beta_grad)
  Rendered C++ wrapper around cudnnBatchNormalizationBackward, invoked in
  CUDNN_BATCHNORM_PER_ACTIVATION mode.
  Preconditions enforced by DLIB_CASSERT, with num = src.k() * src.nr() * src.nc():
    src.num_samples() > 1;
    means, invstds, gamma, gamma_grad and beta_grad each have size() equal to num;
    gradient_input has the same dimensions as src and as src_grad;
    eps > 0.
  Scaling factors handed to cuDNN:
    in_scale = 1 and out_scale = 1 for the data gradient (src_grad);
    in_scale_params = 1 and out_scale_params = 0 for the parameter gradients
    (gamma_grad, beta_grad).
  NOTE(review): these values presumably mean the result is added to the existing contents
  of src_grad while gamma_grad and beta_grad are overwritten; confirm against the cuDNN
  documentation for the alpha/beta DataDiff and ParamDiff arguments.
  Remaining arguments, in call order: src, gradient_input and src_grad (each as descriptor
  plus device pointer), descriptor(gamma) with the gamma, gamma_grad and beta_grad
  pointers, eps, then the means and invstds pointers.
  The returned cuDNN status goes through CHECK_CUDNN, which is defined outside this excerpt.
-->
<font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize_gradient'></a>batch_normalize_gradient</b><font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> means,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> invstds,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gamma,
tensor<font color='#5555FF'>&amp;</font> src_grad,
tensor<font color='#5555FF'>&amp;</font> gamma_grad,
tensor<font color='#5555FF'>&amp;</font> beta_grad
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> num <font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font>src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font>src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>means.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>invstds.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>gamma.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>gamma_grad.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>beta_grad.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gradient_input, src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gradient_input, src_grad<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>eps <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale_params <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale_params <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationBackward</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
CUDNN_BATCHNORM_PER_ACTIVATION,
<font color='#5555FF'>&amp;</font>in_scale,
<font color='#5555FF'>&amp;</font>out_scale,
<font color='#5555FF'>&amp;</font>in_scale_params,
<font color='#5555FF'>&amp;</font>out_scale_params,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>,
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src_grad<font face='Lucida Console'>)</font>,
src_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>,
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
gamma_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
beta_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
eps,
means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
invstds.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<!--
  batch_normalize_conv_inference: inference-time batch normalization of src into dest,
  using the supplied running_means / running_variances.

  Preconditions (all enforced by the single DLIB_CASSERT below, which prints every
  dimension involved when it fails):
    * gamma.num_samples() == 1, gamma.nr() == 1, gamma.nc() == 1, gamma.k() == src.k()
    * beta, running_means and running_variances have the same dimensions as gamma
    * eps > 0

  Effects: dest is resized with dest.copy_size(src), then
  cudnnBatchNormalizationForwardInference is invoked in CUDNN_BATCHNORM_SPATIAL mode.

  NOTE(review): what CHECK_CUDNN does with a failing status is defined outside this
  section of the file; confirm there.
--><font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize_conv_inference'></a>batch_normalize_conv_inference</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps,
resizable_tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gamma,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> beta,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> running_means,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> running_variances
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, beta<font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, running_means<font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gamma, running_variances<font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
eps <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>,
"<font color='#CC0000'>\ngamma.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_means.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_means.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_means.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_means.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_means.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_means.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_means.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_means.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_variances.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_variances.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_variances.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_variances.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_variances.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_variances.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nrunning_variances.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> running_variances.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\neps: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> eps
<font face='Lucida Console'>)</font>;
<!-- in_scale and out_scale are passed by address as the two scaling arguments of the cuDNN call. NOTE(review): presumably these are cuDNN's alpha/beta blend factors, in which case 1 and 0 mean dest is overwritten rather than blended; confirm against the cuDNN documentation. --><font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
dest.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationForwardInference</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
CUDNN_BATCHNORM_SPATIAL,
<font color='#5555FF'>&amp;</font>in_scale,
<font color='#5555FF'>&amp;</font>out_scale,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<!-- descriptor(gamma) is the only descriptor supplied for the gamma, beta, running_means and running_variances device pointers that follow; the assert above guarantees those four tensors share gamma's dimensions. --><font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>,
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
beta.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
running_means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
running_variances.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
eps<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!--
  batch_normalize_conv: training-time forward batch normalization of src into dest via
  cudnnBatchNormalizationForwardTraining in CUDNN_BATCHNORM_SPATIAL mode.

  Preconditions (DLIB_CASSERTs below):
    * 0 <= averaging_factor <= 1
    * unless averaging_factor == 1, running_means must already have the dimensions of
      means, and running_variances those of invstds. These two checks run on entry,
      before any of the tensors are resized further down.
    * src.num_samples() > 1
    * gamma and beta each have num_samples() == 1, nr() == 1, nc() == 1 and
      k() == src.k()
    * eps > 0

  Outputs: dest is sized like src; means is sized with set_size(1, src.k());
  invstds, running_means and running_variances are sized like means. means and
  invstds are handed to cuDNN as the last two pointer arguments of the call.

  NOTE(review): what CHECK_CUDNN does with a failing status is defined outside this
  section of the file; confirm there.
--><font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize_conv'></a>batch_normalize_conv</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps,
resizable_tensor<font color='#5555FF'>&amp;</font> dest,
resizable_tensor<font color='#5555FF'>&amp;</font> means,
resizable_tensor<font color='#5555FF'>&amp;</font> invstds,
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> averaging_factor,
resizable_tensor<font color='#5555FF'>&amp;</font> running_means,
resizable_tensor<font color='#5555FF'>&amp;</font> running_variances,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gamma,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> beta
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#979000'>0</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> averaging_factor <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> averaging_factor <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>1</font>, "<font color='#CC0000'>averaging_factor: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> averaging_factor<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>averaging_factor<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>running_means,means<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>averaging_factor<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>1</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>running_variances,invstds<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
eps <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>,
"<font color='#CC0000'>\ngamma.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\ngamma.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.num_samples(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nbeta.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> beta.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.k(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\nsrc.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\neps: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> eps
<font face='Lucida Console'>)</font>;
<!-- in_scale and out_scale are passed by address as the two scaling arguments of the cuDNN call. NOTE(review): presumably cuDNN's alpha/beta blend factors, so 1 and 0 would overwrite dest; confirm against the cuDNN documentation. --><font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<!-- Size every output before the cuDNN call: dest follows src, means gets one sample with src.k() channels, and the remaining three tensors follow means. --><!-- set_size(1, src.k()) is called with two arguments only; the resulting nr/nc come from set_size's defaults, which are declared outside this section. -->dest.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>;
means.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>1</font>, src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
invstds.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>;
running_means.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>;
running_variances.<font color='#BB00BB'>copy_size</font><font face='Lucida Console'>(</font>means<font face='Lucida Console'>)</font>;
<font color='#009900'>// cuDNN requires that running_means and running_variances be initialized to
</font> <font color='#009900'>// some valid float values even if the averaging factor would have ignored
</font> <font color='#009900'>// them.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>averaging_factor <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>
<b>{</b>
running_means <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
running_variances <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<b>}</b>
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationForwardTraining</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
CUDNN_BATCHNORM_SPATIAL,
<font color='#5555FF'>&amp;</font>in_scale,
<font color='#5555FF'>&amp;</font>out_scale,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>,
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
beta.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
averaging_factor,
running_means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
running_variances.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
eps,
means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
invstds.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!--
  batch_normalize_conv_gradient: backward pass of the convolutional batch normalization,
  delegated to cudnnBatchNormalizationBackward in CUDNN_BATCHNORM_SPATIAL mode.

  Preconditions (DLIB_CASSERTs below):
    * means, invstds, gamma, gamma_grad and beta_grad each hold exactly src.k() elements
    * gradient_input has the same dimensions as src and as src_grad
    * eps > 0

  Inputs handed to cuDNN: src, gradient_input, gamma, eps, means, invstds.
  Tensors passed as non-const for cuDNN to write: src_grad, gamma_grad, beta_grad.

  NOTE(review): what CHECK_CUDNN does with a failing status is defined outside this
  section of the file; confirm there.
--><font color='#0000FF'><u>void</u></font> <b><a name='batch_normalize_conv_gradient'></a>batch_normalize_conv_gradient</b><font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> eps,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> means,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> invstds,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gamma,
tensor<font color='#5555FF'>&amp;</font> src_grad,
tensor<font color='#5555FF'>&amp;</font> gamma_grad,
tensor<font color='#5555FF'>&amp;</font> beta_grad
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>means.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>invstds.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>gamma.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>gamma_grad.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font><font face='Lucida Console'>)</font>beta_grad.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gradient_input, src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gradient_input, src_grad<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>eps <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<!-- Four scaling constants are passed by address, in this order, right after the mode argument. NOTE(review): presumably they map to cuDNN's alphaDataDiff, betaDataDiff, alphaParamDiff and betaParamDiff; if so, out_scale == 1 accumulates into the existing contents of src_grad while out_scale_params == 0 overwrites gamma_grad and beta_grad. Confirm against the cuDNN documentation before changing either value. --><font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale_params <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale_params <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnBatchNormalizationBackward</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
CUDNN_BATCHNORM_SPATIAL,
<font color='#5555FF'>&amp;</font>in_scale,
<font color='#5555FF'>&amp;</font>out_scale,
<font color='#5555FF'>&amp;</font>in_scale_params,
<font color='#5555FF'>&amp;</font>out_scale_params,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>,
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src_grad<font face='Lucida Console'>)</font>,
src_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<!-- descriptor(gamma) is the only descriptor supplied for the gamma, gamma_grad and beta_grad pointers that follow; the size asserts above require each of them to hold src.k() elements. --><font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font>,
gamma.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
gamma_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
beta_grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
eps,
means.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
invstds.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font> <font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<!--
  tensor_conv default constructor.
  The initializer list sets filter_handle and conv_handle to nullptr and the three
  algorithm selectors (forward_algo, backward_data_algo, backward_filters_algo) to 0.
  It then calls clear(), which is defined directly after this constructor. Because both
  handles are already nullptr, the "if (filter_handle)" / "if (conv_handle)" guards at
  the top of clear() skip the cudnnDestroy* calls, and clear() goes on to zero its
  bookkeeping members.
-->tensor_conv::
<b><a name='tensor_conv'></a>tensor_conv</b><font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> :
filter_handle<font face='Lucida Console'>(</font>nullptr<font face='Lucida Console'>)</font>,
conv_handle<font face='Lucida Console'>(</font>nullptr<font face='Lucida Console'>)</font>,
forward_algo<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>,
backward_data_algo<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>,
backward_filters_algo<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Releases any cuDNN descriptors held by this object and returns every cached</font>
<font color='#009900'>// setting (sizes, strides, paddings, algorithm choices, workspaces) to its empty state.</font>
<font color='#0000FF'><u>void</u></font> tensor_conv::
<b><a name='clear'></a>clear</b> <font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Destroy each descriptor that is currently held and forget its handle.  The return</font>
    <font color='#009900'>// codes of the destroy calls are not inspected.</font>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>filter_handle<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#BB00BB'>cudnnDestroyFilterDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle<font face='Lucida Console'>)</font>;
        filter_handle <font color='#5555FF'>=</font> nullptr;
    <b>}</b>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>conv_handle<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#BB00BB'>cudnnDestroyConvolutionDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle<font face='Lucida Console'>)</font>;
        conv_handle <font color='#5555FF'>=</font> nullptr;
    <b>}</b>

    <font color='#009900'>// Output dimensions.</font>
    out_num_samples <font color='#5555FF'>=</font> out_k <font color='#5555FF'>=</font> out_nr <font color='#5555FF'>=</font> out_nc <font color='#5555FF'>=</font> <font color='#979000'>0</font>;

    <font color='#009900'>// Strides and paddings.</font>
    stride_y <font color='#5555FF'>=</font> stride_x <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    padding_y <font color='#5555FF'>=</font> padding_x <font color='#5555FF'>=</font> <font color='#979000'>0</font>;

    <font color='#009900'>// Remembered dimensions of the data and filter tensors.</font>
    data_num_samples <font color='#5555FF'>=</font> data_k <font color='#5555FF'>=</font> data_nr <font color='#5555FF'>=</font> data_nc <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    filters_num_samples <font color='#5555FF'>=</font> filters_k <font color='#5555FF'>=</font> filters_nr <font color='#5555FF'>=</font> filters_nc <font color='#5555FF'>=</font> <font color='#979000'>0</font>;

    <font color='#009900'>// Algorithm selections and their workspace sizes.</font>
    forward_algo <font color='#5555FF'>=</font> backward_data_algo <font color='#5555FF'>=</font> backward_filters_algo <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    forward_workspace_size_in_bytes <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    backward_data_workspace_size_in_bytes <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    backward_filters_workspace_size_in_bytes <font color='#5555FF'>=</font> <font color='#979000'>0</font>;

    <font color='#009900'>// Drop the workspace buffers themselves.</font>
    forward_workspace.<font color='#BB00BB'>reset</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    backward_data_workspace.<font color='#BB00BB'>reset</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    backward_filters_workspace.<font color='#BB00BB'>reset</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Given an array of cudnn algorithm performance results, like</font>
<font color='#009900'>// cudnnConvolutionFwdAlgoPerf_t, pick the best one to use.</font>
<font color='#009900'>//</font>
<font color='#009900'>// T must expose .status, .algo and .memory members.  perf_results must not be empty</font>
<font color='#009900'>// and its first entry is passed through CHECK_CUDNN before anything is chosen.</font>
<font color='#009900'>//</font>
<font color='#009900'>// Returns perf_results[0].algo when dnn_prefer_fastest_algorithms() is true (the cuDNN</font>
<font color='#009900'>// documentation states the Find* functions report results sorted by compute time).</font>
<font color='#009900'>// Otherwise returns the algo of the successful entry with the smallest .memory value.</font>
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> T<font color='#5555FF'>&gt;</font>
<b><a name='decltype'></a>decltype</b><font face='Lucida Console'>(</font>std::declval<font color='#5555FF'>&lt;</font>T<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.algo<font face='Lucida Console'>)</font> <b><a name='pick_best_algorithm'></a>pick_best_algorithm</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>T<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&amp;</font>perf_results<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#5555FF'>!</font>perf_results.<font color='#BB00BB'>empty</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font>perf_results[<font color='#979000'>0</font>].status<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>dnn_prefer_fastest_algorithms</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <font color='#0000FF'>return</font> perf_results[<font color='#979000'>0</font>].algo;

    <font color='#009900'>// Otherwise we find the algorithm that has a good status and uses the least amount</font>
    <font color='#009900'>// of memory.  best_alg starts out as the first entry, which was validated above, so</font>
    <font color='#009900'>// we never return an uninitialized value even if no entry passes the test below.</font>
    <font color='#0000FF'><u>size_t</u></font> best_memory <font color='#5555FF'>=</font> std::numeric_limits<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>size_t</u></font><font color='#5555FF'>&gt;</font>::<font color='#BB00BB'>max</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#BB00BB'>decltype</font><font face='Lucida Console'>(</font>std::declval<font color='#5555FF'>&lt;</font>T<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.algo<font face='Lucida Console'>)</font> best_alg <font color='#5555FF'>=</font> perf_results[<font color='#979000'>0</font>].algo;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> perf : perf_results<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>perf.status <font color='#5555FF'>=</font><font color='#5555FF'>=</font> CUDNN_STATUS_SUCCESS <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> perf.memory <font color='#5555FF'>&lt;</font> best_memory<font face='Lucida Console'>)</font>
        <b>{</b>
            best_memory <font color='#5555FF'>=</font> perf.memory;
            best_alg <font color='#5555FF'>=</font> perf.algo;
        <b>}</b>
    <b>}</b>
    <font color='#0000FF'>return</font> best_alg;
<b>}</b>
<font color='#009900'>// Chooses the cuDNN algorithms for the forward, backward-data and backward-filters</font>
<font color='#009900'>// convolution passes and stores them in forward_algo, backward_data_algo and</font>
<font color='#009900'>// backward_filters_algo.</font>
<font color='#009900'>//</font>
<font color='#009900'>//   data      - input tensor; its descriptor is handed to the cuDNN query functions.</font>
<font color='#009900'>//   dest_desc - descriptor of the convolution output.</font>
<font color='#009900'>//</font>
<font color='#009900'>// Reads filter_handle, conv_handle, the strides, the paddings and filters_nr/filters_nc,</font>
<font color='#009900'>// so those members must be set before this is called.  Selections are remembered in a</font>
<font color='#009900'>// thread_local map keyed on (stride_y, stride_x, padding_y, padding_x, filters_nr,</font>
<font color='#009900'>// filters_nc).</font>
<font color='#009900'>// NOTE(review): the key does not include the data dimensions, the filter count or the</font>
<font color='#009900'>// value of dnn_prefer_fastest_algorithms(), so a cached choice is reused when only</font>
<font color='#009900'>// those differ -- confirm this is intended.</font>
<font color='#0000FF'><u>void</u></font> tensor_conv::
<b><a name='select_best_algorithms'></a>select_best_algorithms</b> <font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> data,
    <font color='#0000FF'>const</font> tensor_descriptor<font color='#5555FF'>&amp;</font> dest_desc
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Calling the cuDNN "find the best algorithm" functions is really slow.  So we keep a</font>
    <font color='#009900'>// cache that tells us what method was best for a particular configuration.</font>
    thread_local std::map<font color='#5555FF'>&lt;</font>std::tuple<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'><u>long</u></font>,<font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font>,
                          std::tuple<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'><u>int</u></font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> config_to_algo_cache;

    <font color='#009900'>// If we have already found good algorithms for this setting then just pull them from</font>
    <font color='#009900'>// the cache.</font>
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> cache_key <font color='#5555FF'>=</font> std::<font color='#BB00BB'>make_tuple</font><font face='Lucida Console'>(</font>stride_y, stride_x, padding_y, padding_x, filters_nr, filters_nc<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> iter <font color='#5555FF'>=</font> config_to_algo_cache.<font color='#BB00BB'>find</font><font face='Lucida Console'>(</font>cache_key<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>iter <font color='#5555FF'>!</font><font color='#5555FF'>=</font> config_to_algo_cache.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
    <b>{</b>
        std::<font color='#BB00BB'>tie</font><font face='Lucida Console'>(</font>forward_algo, backward_data_algo, backward_filters_algo<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> iter<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>second;
        <font color='#0000FF'>return</font>;
    <b>}</b>

    <font color='#009900'>// Pick which forward algorithm we will use.  (No workspace is allocated in this</font>
    <font color='#009900'>// function; only the algorithm ids are recorded.)</font>
    cudnnConvolutionFwdAlgo_t forward_best_algo;
<font color='#0000FF'>#if</font> CUDNN_MAJOR <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>8</font>
    <b>{</b>
        <font color='#0000FF'><u>int</u></font> num_possible_algorithms <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
        <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionForwardAlgorithmMaxCount</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, <font color='#5555FF'>&amp;</font>num_possible_algorithms<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        std::vector<font color='#5555FF'>&lt;</font>cudnnConvolutionFwdAlgoPerf_t<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>perf_results</font><font face='Lucida Console'>(</font>num_possible_algorithms<font face='Lucida Console'>)</font>;
        <font color='#0000FF'><u>int</u></font> num_algorithms <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
        <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnFindConvolutionForwardAlgorithm</font><font face='Lucida Console'>(</font>
                <font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
                <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
                <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
                <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
                <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>,
                num_possible_algorithms,
                <font color='#5555FF'>&amp;</font>num_algorithms,
                perf_results.<font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        <font color='#009900'>// Keep only the entries cuDNN actually filled in.</font>
        perf_results.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>num_algorithms<font face='Lucida Console'>)</font>;
        forward_best_algo <font color='#5555FF'>=</font> <font color='#BB00BB'>pick_best_algorithm</font><font face='Lucida Console'>(</font>perf_results<font face='Lucida Console'>)</font>;
    <b>}</b>
<font color='#0000FF'>#else</font>
    <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionForwardAlgorithm</font><font face='Lucida Console'>(</font>
            <font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
            <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
            <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
            <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>,
            <font color='#BB00BB'>dnn_prefer_fastest_algorithms</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
            std::numeric_limits<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>size_t</u></font><font color='#5555FF'>&gt;</font>::<font color='#BB00BB'>max</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            <font color='#5555FF'>&amp;</font>forward_best_algo<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>#endif</font>
    forward_algo <font color='#5555FF'>=</font> forward_best_algo;

    <font color='#009900'>// Pick which backward data algorithm we will use.</font>
    cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
<font color='#0000FF'>#if</font> CUDNN_MAJOR <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>8</font>
    <b>{</b>
        <font color='#009900'>// Size the result buffer with the backward *data* max count, since that is the</font>
        <font color='#009900'>// family of algorithms being searched below.</font>
        <font color='#0000FF'><u>int</u></font> num_possible_algorithms <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
        <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionBackwardDataAlgorithmMaxCount</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, <font color='#5555FF'>&amp;</font>num_possible_algorithms<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        std::vector<font color='#5555FF'>&lt;</font>cudnnConvolutionBwdDataAlgoPerf_t<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>perf_results</font><font face='Lucida Console'>(</font>num_possible_algorithms<font face='Lucida Console'>)</font>;
        <font color='#0000FF'><u>int</u></font> num_algorithms <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
        <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnFindConvolutionBackwardDataAlgorithm</font><font face='Lucida Console'>(</font>
                <font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
                <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
                <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>,
                <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
                <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
                num_possible_algorithms,
                <font color='#5555FF'>&amp;</font>num_algorithms,
                perf_results.<font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        perf_results.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>num_algorithms<font face='Lucida Console'>)</font>;
        backward_data_best_algo <font color='#5555FF'>=</font> <font color='#BB00BB'>pick_best_algorithm</font><font face='Lucida Console'>(</font>perf_results<font face='Lucida Console'>)</font>;
    <b>}</b>
<font color='#0000FF'>#else</font>
    <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionBackwardDataAlgorithm</font><font face='Lucida Console'>(</font>
            <font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
            <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>,
            <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
            <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
            <font color='#BB00BB'>dnn_prefer_fastest_algorithms</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE,
            std::numeric_limits<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>size_t</u></font><font color='#5555FF'>&gt;</font>::<font color='#BB00BB'>max</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            <font color='#5555FF'>&amp;</font>backward_data_best_algo<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>#endif</font>
    backward_data_algo <font color='#5555FF'>=</font> backward_data_best_algo;

    <font color='#009900'>// Pick which backward filters algorithm we will use.</font>
    cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
<font color='#0000FF'>#if</font> CUDNN_MAJOR <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>8</font>
    <b>{</b>
        <font color='#0000FF'><u>int</u></font> num_possible_algorithms <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
        <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionBackwardFilterAlgorithmMaxCount</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, <font color='#5555FF'>&amp;</font>num_possible_algorithms<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        std::vector<font color='#5555FF'>&lt;</font>cudnnConvolutionBwdFilterAlgoPerf_t<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>perf_results</font><font face='Lucida Console'>(</font>num_possible_algorithms<font face='Lucida Console'>)</font>;
        <font color='#0000FF'><u>int</u></font> num_algorithms <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
        <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnFindConvolutionBackwardFilterAlgorithm</font><font face='Lucida Console'>(</font>
                <font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
                <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
                <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>,
                <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
                <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
                num_possible_algorithms,
                <font color='#5555FF'>&amp;</font>num_algorithms,
                perf_results.<font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        perf_results.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>num_algorithms<font face='Lucida Console'>)</font>;
        backward_filters_best_algo <font color='#5555FF'>=</font> <font color='#BB00BB'>pick_best_algorithm</font><font face='Lucida Console'>(</font>perf_results<font face='Lucida Console'>)</font>;
    <b>}</b>
<font color='#0000FF'>#else</font>
    <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionBackwardFilterAlgorithm</font><font face='Lucida Console'>(</font>
            <font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
            <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>,
            <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
            <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
            <font color='#BB00BB'>dnn_prefer_fastest_algorithms</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE,
            std::numeric_limits<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>size_t</u></font><font color='#5555FF'>&gt;</font>::<font color='#BB00BB'>max</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            <font color='#5555FF'>&amp;</font>backward_filters_best_algo<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>#endif</font>

    <font color='#009900'>// cuDNN 5.1 has a bug that causes</font>
    <font color='#009900'>// cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd</font>
    <font color='#009900'>// algorithm even for cases where cuDNN doesn't support it, leading to</font>
    <font color='#009900'>// incorrect outputs.  So here we check if we are in a case where winograd</font>
    <font color='#009900'>// isn't supported and manually overrule</font>
    <font color='#009900'>// cudnnGetConvolutionBackwardFilterAlgorithm() by picking a safe</font>
    <font color='#009900'>// algorithm.  Note that this override is applied on every cuDNN version, including</font>
    <font color='#009900'>// the CUDNN_MAJOR &gt;= 8 path above.</font>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>dnn_prefer_fastest_algorithms</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
        <font color='#5555FF'>!</font><font face='Lucida Console'>(</font>stride_x <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> stride_y <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> <font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>filters_nr<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>3</font><font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>filters_nc<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>3</font><font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font face='Lucida Console'>(</font>filters_nr<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>5</font><font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>filters_nc<font color='#5555FF'>=</font><font color='#5555FF'>=</font><font color='#979000'>5</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
    <font face='Lucida Console'>)</font>
    <b>{</b>
        backward_filters_best_algo <font color='#5555FF'>=</font> CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
    <b>}</b>
    backward_filters_algo <font color='#5555FF'>=</font> backward_filters_best_algo;

    <font color='#009900'>// Save this algorithm selection in the cache</font>
    config_to_algo_cache[cache_key] <font color='#5555FF'>=</font> std::<font color='#BB00BB'>make_tuple</font><font face='Lucida Console'>(</font>forward_algo, backward_data_algo, backward_filters_algo<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'><u>void</u></font> tensor_conv::
<b><a name='setup'></a>setup</b><font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> data,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> filters,
<font color='#0000FF'><u>int</u></font> stride_y_,
<font color='#0000FF'><u>int</u></font> stride_x_,
<font color='#0000FF'><u>int</u></font> padding_y_,
<font color='#0000FF'><u>int</u></font> padding_x_
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// if the last call to setup gave the same exact settings then don't do
</font> <font color='#009900'>// anything.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>data_num_samples <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
data_k <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
data_nr <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
data_nc <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
stride_y_ <font color='#5555FF'>=</font><font color='#5555FF'>=</font> stride_y <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
stride_x_ <font color='#5555FF'>=</font><font color='#5555FF'>=</font> stride_x <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
padding_y_ <font color='#5555FF'>=</font><font color='#5555FF'>=</font> padding_y <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
padding_x_ <font color='#5555FF'>=</font><font color='#5555FF'>=</font> padding_x <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
filters_num_samples <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
filters_k <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
filters_nr <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
filters_nc <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>return</font>;
<b>}</b>
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>try</font>
<b>{</b>
stride_y <font color='#5555FF'>=</font> stride_y_;
stride_x <font color='#5555FF'>=</font> stride_x_;
padding_y <font color='#5555FF'>=</font> padding_y_;
padding_x <font color='#5555FF'>=</font> padding_x_;
data_num_samples <font color='#5555FF'>=</font> data.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
data_k <font color='#5555FF'>=</font> data.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
data_nr <font color='#5555FF'>=</font> data.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
data_nc <font color='#5555FF'>=</font> data.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
filters_num_samples <font color='#5555FF'>=</font> filters.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
filters_k <font color='#5555FF'>=</font> filters.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
filters_nr <font color='#5555FF'>=</font> filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
filters_nc <font color='#5555FF'>=</font> filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreateFilterDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnFilterDescriptor_t<font color='#5555FF'>*</font><font face='Lucida Console'>)</font><font color='#5555FF'>&amp;</font>filter_handle<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetFilter4dDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
filters.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
filters.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreateConvolutionDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnConvolutionDescriptor_t<font color='#5555FF'>*</font><font face='Lucida Console'>)</font><font color='#5555FF'>&amp;</font>conv_handle<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>#if</font> CUDNN_MAJOR <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>6</font>
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetConvolution2dDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
padding_y, <font color='#009900'>// vertical padding
</font> padding_x, <font color='#009900'>// horizontal padding
</font> stride_y,
stride_x,
<font color='#979000'>1</font>, <font color='#979000'>1</font>, <font color='#009900'>// must be 1,1
</font> CUDNN_CROSS_CORRELATION,
CUDNN_DATA_FLOAT<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <font color='#009900'>// could also be CUDNN_CONVOLUTION
</font><font color='#0000FF'>#else</font>
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetConvolution2dDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
padding_y, <font color='#009900'>// vertical padding
</font> padding_x, <font color='#009900'>// horizontal padding
</font> stride_y,
stride_x,
<font color='#979000'>1</font>, <font color='#979000'>1</font>, <font color='#009900'>// must be 1,1
</font> CUDNN_CROSS_CORRELATION<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <font color='#009900'>// could also be CUDNN_CONVOLUTION
</font><font color='#0000FF'>#endif</font>
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolution2dForwardOutputDim</font><font face='Lucida Console'>(</font>
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
<font color='#5555FF'>&amp;</font>out_num_samples,
<font color='#5555FF'>&amp;</font>out_k,
<font color='#5555FF'>&amp;</font>out_nr,
<font color='#5555FF'>&amp;</font>out_nc<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
tensor_descriptor dest_desc;
dest_desc.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font>out_num_samples,out_k,out_nr,out_nc<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>select_best_algorithms</font><font face='Lucida Console'>(</font>data, dest_desc<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionForwardWorkspaceSize</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font>cudnnConvolutionFwdAlgo_t<font face='Lucida Console'>)</font>forward_algo,
<font color='#5555FF'>&amp;</font>forward_workspace_size_in_bytes<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionBackwardDataWorkspaceSize</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font>cudnnConvolutionBwdDataAlgo_t<font face='Lucida Console'>)</font>backward_data_algo,
<font color='#5555FF'>&amp;</font>backward_data_workspace_size_in_bytes<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetConvolutionBackwardFilterWorkspaceSize</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest_desc<font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
<font face='Lucida Console'>(</font>cudnnConvolutionBwdFilterAlgo_t<font face='Lucida Console'>)</font>backward_filters_algo,
<font color='#5555FF'>&amp;</font>backward_filters_workspace_size_in_bytes<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>...<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>throw</font>;
<b>}</b>
<b>}</b>
<!-- tensor_conv destructor: all cleanup is delegated to clear(). -->tensor_conv::
~<b><a name='tensor_conv'></a>tensor_conv</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- tensor_conv::operator() (resizable_tensor overload).
     Asserts that setup() has been called (stride_y and stride_x are positive),
     resizes output to the cached out_num_samples, out_k, out_nr, out_nc values,
     and then forwards to the plain tensor overload of operator() with the same
     add_to_output, data and filters arguments. --><font color='#0000FF'><u>void</u></font> tensor_conv::<b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> add_to_output,
resizable_tensor<font color='#5555FF'>&amp;</font> output,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> data,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> filters
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>stride_y <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> stride_x <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>, "<font color='#CC0000'>You must call setup() before calling this function</font>"<font face='Lucida Console'>)</font>;
output.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font>out_num_samples, out_k, out_nr, out_nc<font face='Lucida Console'>)</font>;
<font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>add_to_output, <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font>tensor<font color='#5555FF'>&amp;</font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>output<font face='Lucida Console'>)</font>, data, filters<font face='Lucida Console'>)</font>;
<b>}</b>
<!-- tensor_conv::operator() (tensor overload).
     Runs cudnnConvolutionForward on data and filters, writing into output.
     Preconditions checked with DLIB_CASSERT before the call:
       * output is not the same object as data or filters
       * filters.k() equals data.k()
       * setup() has been called (stride_y and stride_x are positive)
       * the filter window fits inside the padded image in both dimensions
       * output has the sample count of data, filters.num_samples() channels,
         and rows and columns equal to 1 + (size + 2*padding - filter size)/stride
     alpha is fixed at 1.  beta is 1 when add_to_output is true and 0 otherwise;
     both are handed to cuDNN as its scaling factors.
     The workspace buffer is re-fetched from device_global_buffer() on every call
     and kept in the forward_workspace member (see the existing comment below). --><font color='#0000FF'><u>void</u></font> tensor_conv::<b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> add_to_output,
tensor<font color='#5555FF'>&amp;</font> output,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> data,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> filters
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>output,data<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>false</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>output,filters<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>false</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>filters.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>stride_y <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> stride_x <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>, "<font color='#CC0000'>You must call setup() before calling this function</font>"<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_x,
"<font color='#CC0000'>Filter windows must be small enough to fit into the padded image.</font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t filters.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t data.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> data.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t padding_x: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> padding_x
<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_y,
"<font color='#CC0000'>Filter windows must be small enough to fit into the padded image.</font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t filters.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t data.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> data.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t padding_y: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> padding_y
<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>output.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> data.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,out_num_samples <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> data.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>output.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> filters.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>output.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font color='#5555FF'>+</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font><font color='#979000'>2</font><font color='#5555FF'>*</font>padding_y<font color='#5555FF'>-</font>filters.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_y<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>output.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font color='#5555FF'>+</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font><font color='#979000'>2</font><font color='#5555FF'>*</font>padding_x<font color='#5555FF'>-</font>filters.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_x<font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> add_to_output ? <font color='#979000'>1</font> : <font color='#979000'>0</font>;
<font color='#009900'>// Since cudnnConvolutionForward() is an asynchronous call, we need to hold a
</font> <font color='#009900'>// reference to the workspace buffer so we can be sure it isn't reallocated
</font> <font color='#009900'>// while the function is still executing on the device. But each time we come
</font> <font color='#009900'>// here, we make sure to grab the latest workspace buffer so that, globally, we
</font> <font color='#009900'>// minimize the number of such buffers.
</font> forward_workspace <font color='#5555FF'>=</font> <font color='#BB00BB'>device_global_buffer</font><font face='Lucida Console'>(</font>forward_workspace_size_in_bytes<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnConvolutionForward</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
data.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
filters.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
<font face='Lucida Console'>(</font>cudnnConvolutionFwdAlgo_t<font face='Lucida Console'>)</font>forward_algo,
forward_workspace,
forward_workspace_size_in_bytes,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>output<font face='Lucida Console'>)</font>,
output.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- tensor_conv::get_gradient_for_data
     Calls cudnnConvolutionBackwardData with filters and gradient_input as inputs
     and data_gradient as the destination, using the stored filter_handle,
     conv_handle and backward_data_algo.
     alpha is fixed at 1.  beta is 1 when add_to_output is true and 0 otherwise.
     The workspace is re-fetched from device_global_buffer() on every call and
     kept in the backward_data_workspace member.
     Unlike operator(), this function performs no DLIB_CASSERT precondition
     checks; failures surface only through CHECK_CUDNN. --><font color='#0000FF'><u>void</u></font> tensor_conv::<b><a name='get_gradient_for_data'></a>get_gradient_for_data</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> add_to_output,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> filters,
tensor<font color='#5555FF'>&amp;</font> data_gradient
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> add_to_output ? <font color='#979000'>1</font> : <font color='#979000'>0</font>;
<font color='#009900'>// Since cudnnConvolutionBackwardData() is an asynchronous call, we need to hold a
</font> <font color='#009900'>// reference to the workspace buffer so we can be sure it isn't reallocated
</font> <font color='#009900'>// while the function is still executing on the device. But each time we come
</font> <font color='#009900'>// here, we make sure to grab the latest workspace buffer so that, globally, we
</font> <font color='#009900'>// minimize the number of such buffers.
</font> backward_data_workspace <font color='#5555FF'>=</font> <font color='#BB00BB'>device_global_buffer</font><font face='Lucida Console'>(</font>backward_data_workspace_size_in_bytes<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnConvolutionBackwardData</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>alpha,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
filters.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>,
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
<font face='Lucida Console'>(</font>cudnnConvolutionBwdDataAlgo_t<font face='Lucida Console'>)</font>backward_data_algo,
backward_data_workspace,
backward_data_workspace_size_in_bytes,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data_gradient<font face='Lucida Console'>)</font>,
data_gradient.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- tensor_conv::get_gradient_for_filters
     Calls cudnnConvolutionBackwardFilter with data and gradient_input as inputs
     and filters_gradient as the destination.  The destination is described by
     the stored filter_handle rather than by a descriptor built from
     filters_gradient.
     alpha is fixed at 1.  beta is 1 when add_to_output is true and 0 otherwise.
     The workspace is re-fetched from device_global_buffer() on every call and
     kept in the backward_filters_workspace member.
     This function performs no DLIB_CASSERT precondition checks; failures
     surface only through CHECK_CUDNN. --><font color='#0000FF'><u>void</u></font> tensor_conv::
<b><a name='get_gradient_for_filters'></a>get_gradient_for_filters</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> add_to_output,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> data,
tensor<font color='#5555FF'>&amp;</font> filters_gradient
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> add_to_output ? <font color='#979000'>1</font> : <font color='#979000'>0</font>;
<font color='#009900'>// Since cudnnConvolutionBackwardFilter() is an asynchronous call, we need to hold a
</font> <font color='#009900'>// reference to the workspace buffer so we can be sure it isn't reallocated
</font> <font color='#009900'>// while the function is still executing on the device. But each time we come
</font> <font color='#009900'>// here, we make sure to grab the latest workspace buffer so that, globally, we
</font> <font color='#009900'>// minimize the number of such buffers.
</font> backward_filters_workspace <font color='#5555FF'>=</font> <font color='#BB00BB'>device_global_buffer</font><font face='Lucida Console'>(</font>backward_filters_workspace_size_in_bytes<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnConvolutionBackwardFilter</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>data<font face='Lucida Console'>)</font>,
data.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>,
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnConvolutionDescriptor_t<font face='Lucida Console'>)</font>conv_handle,
<font face='Lucida Console'>(</font>cudnnConvolutionBwdFilterAlgo_t<font face='Lucida Console'>)</font>backward_filters_algo,
backward_filters_workspace,
backward_filters_workspace_size_in_bytes,
<font color='#5555FF'>&amp;</font>beta,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnFilterDescriptor_t<font face='Lucida Console'>)</font>filter_handle,
filters_gradient.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font> <font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<!-- pooling default constructor: starts with a null descriptor handle and with
     the window size, strides and padding all set to zero.  do_max_pooling is
     not set in this initializer list. -->pooling::<b><a name='pooling'></a>pooling</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> : handle<font face='Lucida Console'>(</font>nullptr<font face='Lucida Console'>)</font>,window_height<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>,window_width<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>,stride_y<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>,stride_x<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>,padding_y<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>, padding_x<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>
<b>{</b>
<b>}</b>
<!-- pooling destructor: all cleanup is delegated to clear(). -->pooling::~<b><a name='pooling'></a>pooling</b><font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- pooling::clear
     Destroys the cuDNN pooling descriptor if handle is non-null, then resets
     handle to nullptr and the window size, strides and padding to zero, which
     matches the values set by the default constructor.
     The status returned by cudnnDestroyPoolingDescriptor is not checked here.
     do_max_pooling is left unchanged. --><font color='#0000FF'><u>void</u></font> pooling::
<b><a name='clear'></a>clear</b><font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>handle<font face='Lucida Console'>)</font>
<font color='#BB00BB'>cudnnDestroyPoolingDescriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font>cudnnPoolingDescriptor_t<font face='Lucida Console'>)</font>handle<font face='Lucida Console'>)</font>;
handle <font color='#5555FF'>=</font> nullptr;
window_height <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
window_width <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
stride_y <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
stride_x <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
padding_y <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
padding_x <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<b>}</b>
<!-- pooling::setup_max_pooling
     Forwards the window, stride and padding arguments to setup() with
     CUDNN_POOLING_MAX as the pooling mode, then sets do_max_pooling to true.
     The flag is written only after setup() returns, so it is not updated if
     setup() throws. --><font color='#0000FF'><u>void</u></font> pooling::
<b><a name='setup_max_pooling'></a>setup_max_pooling</b><font face='Lucida Console'>(</font>
<font color='#0000FF'><u>int</u></font> window_height_,
<font color='#0000FF'><u>int</u></font> window_width_,
<font color='#0000FF'><u>int</u></font> stride_y_,
<font color='#0000FF'><u>int</u></font> stride_x_,
<font color='#0000FF'><u>int</u></font> padding_y_,
<font color='#0000FF'><u>int</u></font> padding_x_
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>setup</font><font face='Lucida Console'>(</font>window_height_, window_width_, stride_y_, stride_x_, padding_y_, padding_x_, CUDNN_POOLING_MAX<font face='Lucida Console'>)</font>;
do_max_pooling <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
<b>}</b>
<!-- pooling::setup_avg_pooling
     Forwards the window, stride and padding arguments to setup() with
     CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING as the pooling mode, then sets
     do_max_pooling to false.  The flag is written only after setup() returns,
     so it is not updated if setup() throws. --><font color='#0000FF'><u>void</u></font> pooling::
<b><a name='setup_avg_pooling'></a>setup_avg_pooling</b><font face='Lucida Console'>(</font>
<font color='#0000FF'><u>int</u></font> window_height_,
<font color='#0000FF'><u>int</u></font> window_width_,
<font color='#0000FF'><u>int</u></font> stride_y_,
<font color='#0000FF'><u>int</u></font> stride_x_,
<font color='#0000FF'><u>int</u></font> padding_y_,
<font color='#0000FF'><u>int</u></font> padding_x_
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>setup</font><font face='Lucida Console'>(</font>window_height_, window_width_, stride_y_, stride_x_, padding_y_, padding_x_, CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING<font face='Lucida Console'>)</font>;
do_max_pooling <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
<b>}</b>
<font color='#009900'>// Creates (or re-creates) the cuDNN pooling descriptor for the given window size,
</font><font color='#009900'>// strides, padding and pooling mode.  pooling_mode is cast to cudnnPoolingMode_t.
</font><font color='#009900'>// The window sizes and strides must be positive and each padding must be
</font><font color='#009900'>// non-negative and smaller than the matching window dimension.
</font><font color='#009900'>// Called by setup_max_pooling() and setup_avg_pooling(), which record the mode in
</font><font color='#009900'>// do_max_pooling after this function returns.  If anything throws while building
</font><font color='#009900'>// the descriptor, the object is reset with clear() and the exception is rethrown.
</font><font color='#0000FF'><u>void</u></font> pooling::
<b><a name='setup'></a>setup</b><font face='Lucida Console'>(</font>
<font color='#0000FF'><u>int</u></font> window_height_,
<font color='#0000FF'><u>int</u></font> window_width_,
<font color='#0000FF'><u>int</u></font> stride_y_,
<font color='#0000FF'><u>int</u></font> stride_x_,
<font color='#0000FF'><u>int</u></font> padding_y_,
<font color='#0000FF'><u>int</u></font> padding_x_,
<font color='#0000FF'><u>int</u></font> pooling_mode
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font> <font face='Lucida Console'>(</font>window_height_ <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> window_width_ <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
stride_y_ <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> stride_x_ <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font> ,
"<font color='#CC0000'>window_height_: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> window_height_
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\t\n window_width_: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> window_width_
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\t\n stride_y_: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> stride_y_
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\t\n stride_x_: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> stride_x_ <font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font> <font color='#979000'>0</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> padding_y_ <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> padding_y_ <font color='#5555FF'>&lt;</font> window_height_ <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#979000'>0</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> padding_x_ <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> padding_x_ <font color='#5555FF'>&lt;</font> window_width_,
"<font color='#CC0000'>window_height_: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> window_height_
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\t\n window_width_: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> window_width_
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\t\n padding_y_: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> padding_y_
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\t\n padding_x_: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> padding_x_ <font face='Lucida Console'>)</font>;
<font color='#009900'>// Skip the rebuild only when the existing descriptor already matches both the
</font> <font color='#009900'>// requested geometry and the requested pooling mode.  Without the mode test,
</font> <font color='#009900'>// switching between max and average pooling with unchanged geometry would
</font> <font color='#009900'>// return early and leave the descriptor in the old mode.  The mode test comes
</font> <font color='#009900'>// last on purpose: a geometry match implies an earlier successful setup (the
</font> <font color='#009900'>// stored sizes are zero otherwise), so do_max_pooling is only read once the
</font> <font color='#009900'>// calling setup_*_pooling() function has stored it.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>window_height <font color='#5555FF'>=</font><font color='#5555FF'>=</font> window_height_ <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
window_width <font color='#5555FF'>=</font><font color='#5555FF'>=</font> window_width_ <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
stride_y <font color='#5555FF'>=</font><font color='#5555FF'>=</font> stride_y_ <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
stride_x <font color='#5555FF'>=</font><font color='#5555FF'>=</font> stride_x_ <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
padding_y <font color='#5555FF'>=</font><font color='#5555FF'>=</font> padding_y_ <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
padding_x <font color='#5555FF'>=</font><font color='#5555FF'>=</font> padding_x_ <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
do_max_pooling <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font face='Lucida Console'>(</font>pooling_mode <font color='#5555FF'>=</font><font color='#5555FF'>=</font> CUDNN_POOLING_MAX<font face='Lucida Console'>)</font>
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>return</font>;
<b>}</b>
<font color='#009900'>// Release any previous descriptor before building the new one.
</font> <font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>try</font>
<b>{</b>
window_height <font color='#5555FF'>=</font> window_height_;
window_width <font color='#5555FF'>=</font> window_width_;
stride_x <font color='#5555FF'>=</font> stride_x_;
stride_y <font color='#5555FF'>=</font> stride_y_;
padding_y <font color='#5555FF'>=</font> padding_y_;
padding_x <font color='#5555FF'>=</font> padding_x_;
cudnnPoolingDescriptor_t poolingDesc;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnCreatePoolingDescriptor</font><font face='Lucida Console'>(</font><font color='#5555FF'>&amp;</font>poolingDesc<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// Store the handle before configuring it so that clear() in the catch block
</font> <font color='#009900'>// destroys the descriptor if the next call throws.
</font> handle <font color='#5555FF'>=</font> poolingDesc;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSetPooling2dDescriptor</font><font face='Lucida Console'>(</font>poolingDesc,
<font face='Lucida Console'>(</font>cudnnPoolingMode_t<font face='Lucida Console'>)</font>pooling_mode,
CUDNN_PROPAGATE_NAN,
window_height,
window_width,
padding_y,
padding_x,
stride_y,
stride_x<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>...<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>throw</font>;
<b>}</b>
<b>}</b>
<!--
  pooling::operator()(dest, src): forward pooling pass through cuDNN.

  Steps visible in the listing below:
    1. Asserts that the pooling window fits inside src once padding is added:
       window_width must not exceed src.nc() + 2*padding_x, and window_height
       must not exceed src.nr() + 2*padding_y.
    2. Queries cudnnGetPooling2dForwardOutputDim for the output dimensions
       (outN, outC, outH, outW) and resizes dest to them with set_size().
    3. Cross-checks the resized dest against src: same num_samples() and k(),
       and nr()/nc() equal to 1 + (size + 2*padding - window)/stride.
    4. Calls cudnnPoolingForward with alpha = 1 and beta = 0.

  Parameters:
    dest : resizable_tensor that is resized and then written.
    src  : input tensor; only read here.

  NOTE(review): handle is cast to cudnnPoolingDescriptor_t without a null check.
  The setup code just above this function assigns handle from a freshly created
  pooling descriptor, so presumably a setup call is required first. Confirm
  against the callers.
  Failures are reported through DLIB_CASSERT and CHECK_CUDNN, whose definitions
  are outside this view.
--><font color='#0000FF'><u>void</u></font> pooling::
<b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font>
resizable_tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>window_width <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_x,
"<font color='#CC0000'>Pooling windows must be small enough to fit into the padded image.</font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t window_width: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> window_width
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t src.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t padding_x: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> padding_x
<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>window_height <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_y,
"<font color='#CC0000'>Pooling windows must be small enough to fit into the padded image.</font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t window_height: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> window_height
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t src.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t padding_y: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> padding_y
<font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'><u>int</u></font> outN;
<font color='#0000FF'><u>int</u></font> outC;
<font color='#0000FF'><u>int</u></font> outH;
<font color='#0000FF'><u>int</u></font> outW;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnGetPooling2dForwardOutputDim</font><font face='Lucida Console'>(</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnPoolingDescriptor_t<font face='Lucida Console'>)</font>handle,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>outN,
<font color='#5555FF'>&amp;</font>outC,
<font color='#5555FF'>&amp;</font>outH,
<font color='#5555FF'>&amp;</font>outW<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
dest.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font>outN,outC,outH,outW<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> src.<font color='#BB00BB'>k</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>+</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_y <font color='#5555FF'>-</font> window_height<font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_y,
"<font color='#CC0000'>\n stride_y: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> stride_y <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\n padding_y: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> padding_y <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\n window_height: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> window_height <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\n src.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\n dest.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> dest.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\n src.nr()/stride_y: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_y<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font> <font color='#5555FF'>+</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>padding_x <font color='#5555FF'>-</font> window_width<font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_x,
"<font color='#CC0000'>\n stride_x: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> stride_x <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\n padding_x: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> padding_x <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\n window_width: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> window_width <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\n src.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\n dest.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> dest.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\n src.nc()/stride_x: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> src.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>stride_x<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnPoolingForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnPoolingDescriptor_t<font face='Lucida Console'>)</font>handle,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!--
  pooling::get_gradient(gradient_input, dest, src, grad): backward pooling pass.

  Asserts that gradient_input has the same dimensions as dest and that src has
  the same dimensions as grad, then calls cudnnPoolingBackward with the stored
  pooling descriptor (handle).

  Parameters:
    gradient_input : gradient with respect to the pooling output; read only.
    dest           : the pooling output; read only.
    src            : the pooling input; read only.
    grad           : tensor that receives the result.

  alpha = 1 and beta = 1. NOTE(review): under cuDNN's alpha/beta scaling
  convention beta = 1 presumably means the result is added to what grad already
  holds rather than overwriting it. Confirm against the cuDNN documentation.
  Unlike the softmax/sigmoid/relu functions later in this file, there is no
  early return for empty tensors here.
--><font color='#0000FF'><u>void</u></font> pooling::<b><a name='get_gradient'></a>get_gradient</b><font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src,
tensor<font color='#5555FF'>&amp;</font> grad
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>gradient_input,dest<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>src,grad<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnPoolingBackward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> cudnnPoolingDescriptor_t<font face='Lucida Console'>)</font>handle,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>,
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>,
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font> <font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<!--
  softmax(dest, src): forward softmax through cuDNN.

  Asserts that dest and src have the same dimensions, returns immediately when
  src.size() is 0, and otherwise calls cudnnSoftmaxForward with
  CUDNN_SOFTMAX_ACCURATE and CUDNN_SOFTMAX_MODE_CHANNEL, alpha = 1 and beta = 0.

  Parameters:
    dest : tensor written by the cuDNN call.
    src  : input tensor; read only.

  This differs from softmax_all only in the mode constant
  (CUDNN_SOFTMAX_MODE_CHANNEL here, CUDNN_SOFTMAX_MODE_INSTANCE there).
--><font color='#0000FF'><u>void</u></font> <b><a name='softmax'></a>softmax</b> <font face='Lucida Console'>(</font>
tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSoftmaxForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!--
  softmax_gradient(grad, dest, gradient_input): backward softmax through cuDNN.

  Asserts that dest, gradient_input and grad all have the same dimensions,
  returns immediately when dest.size() is 0, and otherwise calls
  cudnnSoftmaxBackward with CUDNN_SOFTMAX_ACCURATE and CUDNN_SOFTMAX_MODE_CHANNEL.

  Parameters:
    grad           : tensor that receives the result.
    dest           : the softmax output; read only.
    gradient_input : gradient with respect to the softmax output; read only.

  beta is 0 when grad and gradient_input are the same object and 1 otherwise.
  NOTE(review): under cuDNN's alpha/beta scaling convention this presumably means
  overwrite when operating in place and accumulate into grad otherwise. Confirm
  against the cuDNN documentation.
--><font color='#0000FF'><u>void</u></font> <b><a name='softmax_gradient'></a>softmax_gradient</b> <font face='Lucida Console'>(</font>
tensor<font color='#5555FF'>&amp;</font> grad,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,grad<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font> ? <font color='#979000'>0</font> : <font color='#979000'>1</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSoftmaxBackward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>,
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>,
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font> <font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<!--
  softmax_all(dest, src): forward softmax through cuDNN using
  CUDNN_SOFTMAX_MODE_INSTANCE.

  Asserts that dest and src have the same dimensions, returns immediately when
  src.size() is 0, and otherwise calls cudnnSoftmaxForward with
  CUDNN_SOFTMAX_ACCURATE, alpha = 1 and beta = 0.

  Parameters:
    dest : tensor written by the cuDNN call.
    src  : input tensor; read only.

  NOTE(review): INSTANCE mode presumably normalizes over all values of each
  sample rather than per spatial location. Confirm against the cuDNN
  documentation for cudnnSoftmaxMode_t.
--><font color='#0000FF'><u>void</u></font> <b><a name='softmax_all'></a>softmax_all</b> <font face='Lucida Console'>(</font>
tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSoftmaxForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_INSTANCE,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!--
  softmax_all_gradient(grad, dest, gradient_input): backward pass matching
  softmax_all; uses CUDNN_SOFTMAX_MODE_INSTANCE.

  Asserts that dest, gradient_input and grad all have the same dimensions,
  returns immediately when dest.size() is 0, and otherwise calls
  cudnnSoftmaxBackward with CUDNN_SOFTMAX_ACCURATE.

  Parameters:
    grad           : tensor that receives the result.
    dest           : the softmax_all output; read only.
    gradient_input : gradient with respect to that output; read only.

  beta is 0 when grad and gradient_input are the same object and 1 otherwise.
  NOTE(review): presumably overwrite when in place, accumulate otherwise, per
  cuDNN's alpha/beta scaling convention. Confirm against the cuDNN documentation.
--><font color='#0000FF'><u>void</u></font> <b><a name='softmax_all_gradient'></a>softmax_all_gradient</b> <font face='Lucida Console'>(</font>
tensor<font color='#5555FF'>&amp;</font> grad,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,grad<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font> ? <font color='#979000'>0</font> : <font color='#979000'>1</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnSoftmaxBackward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_INSTANCE,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>,
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>,
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font> <font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<!--
  sigmoid(dest, src): forward sigmoid activation through cuDNN.

  Asserts that dest and src have the same dimensions, returns immediately when
  src.size() is 0, and otherwise calls cudnnActivationForward with the
  descriptor returned by sigmoid_activation_descriptor(), alpha = 1 and beta = 0.

  Parameters:
    dest : tensor written by the cuDNN call.
    src  : input tensor; read only.

  sigmoid_activation_descriptor() is defined outside this view.
--><font color='#0000FF'><u>void</u></font> <b><a name='sigmoid'></a>sigmoid</b> <font face='Lucida Console'>(</font>
tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>sigmoid_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!--
  sigmoid_gradient(grad, dest, gradient_input): backward sigmoid through cuDNN.

  Asserts that dest, gradient_input and grad all have the same dimensions,
  returns immediately when dest.size() is 0, and otherwise calls
  cudnnActivationBackward with the descriptor returned by
  sigmoid_activation_descriptor().

  Parameters:
    grad           : tensor that receives the result.
    dest           : the sigmoid output; read only.
    gradient_input : gradient with respect to the sigmoid output; read only.

  dest (descriptor and device pointer) is passed twice to the cuDNN call: once
  before gradient_input and once after it. NOTE(review): in the
  cudnnActivationBackward signature those positions are presumably the forward
  output (y) and the forward input (x). The original input is not a parameter of
  this function, so dest stands in for it. Confirm this is valid for the sigmoid
  activation mode against the cuDNN documentation.

  beta is 0 when grad and gradient_input are the same object and 1 otherwise.
  NOTE(review): presumably overwrite when in place, accumulate otherwise.
--><font color='#0000FF'><u>void</u></font> <b><a name='sigmoid_gradient'></a>sigmoid_gradient</b> <font face='Lucida Console'>(</font>
tensor<font color='#5555FF'>&amp;</font> grad,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
<font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,grad<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font> ? <font color='#979000'>0</font> : <font color='#979000'>1</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationBackward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>sigmoid_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>,
gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>,
grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<!--
  relu(dest, src): forward ReLU activation through cuDNN.

  Asserts that dest and src have the same dimensions, returns immediately when
  src.size() is 0, and otherwise calls cudnnActivationForward with the
  descriptor returned by relu_activation_descriptor(), alpha = 1 and beta = 0.

  Parameters:
    dest : tensor written by the cuDNN call.
    src  : input tensor; read only.

  relu_activation_descriptor() is defined outside this view.
--><font color='#0000FF'><u>void</u></font> <b><a name='relu'></a>relu</b> <font face='Lucida Console'>(</font>
tensor<font color='#5555FF'>&amp;</font> dest,
<font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> alpha <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> beta <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationForward</font><font face='Lucida Console'>(</font><font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#BB00BB'>relu_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>alpha,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>,
src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
<font color='#5555FF'>&amp;</font>beta,
<font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>,
dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='relu_gradient'></a>relu_gradient</b> <font face='Lucida Console'>(</font>
    tensor<font color='#5555FF'>&amp;</font> grad,
    <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> dest,
    <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Runs cudnnActivationBackward with the relu activation descriptor, using dest</font>
    <font color='#009900'>// and gradient_input as inputs and grad as the output tensor.  All three</font>
    <font color='#009900'>// tensors must have the same dimensions.</font>
    <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
        <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
        <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,grad<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font face='Lucida Console'>)</font>;

    <font color='#009900'>// An empty tensor needs no work, so skip the cuDNN call entirely.</font>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>return</font>;
    <b>}</b>

    <font color='#009900'>// The output scaling factor (cuDNN's beta argument) is 0 when grad and</font>
    <font color='#009900'>// gradient_input are the same object and 1 otherwise.</font>
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> in_place <font color='#5555FF'>=</font> <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> in_place ? <font color='#979000'>0</font> : <font color='#979000'>1</font>;
    <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationBackward</font><font face='Lucida Console'>(</font>
        <font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#BB00BB'>relu_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#5555FF'>&amp;</font>in_scale,
        <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#5555FF'>&amp;</font>out_scale,
        <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>, grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font>
<font color='#0000FF'><u>void</u></font> <b><a name='tanh'></a>tanh</b> <font face='Lucida Console'>(</font>
    tensor<font color='#5555FF'>&amp;</font> dest,
    <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> src
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Runs cudnnActivationForward with the tanh activation descriptor, reading</font>
    <font color='#009900'>// from src and writing into dest.  Both tensors must have the same dimensions.</font>
    <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,src<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// An empty tensor needs no work, so skip the cuDNN call entirely.</font>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>src.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>return</font>;
    <b>}</b>

    <font color='#009900'>// Scaling factors handed to cuDNN as its alpha (1) and beta (0) arguments.</font>
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> src_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> dest_scale <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationForward</font><font face='Lucida Console'>(</font>
        <font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#BB00BB'>tanh_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#5555FF'>&amp;</font>src_scale,
        <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>src<font face='Lucida Console'>)</font>, src.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#5555FF'>&amp;</font>dest_scale,
        <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='tanh_gradient'></a>tanh_gradient</b> <font face='Lucida Console'>(</font>
    tensor<font color='#5555FF'>&amp;</font> grad,
    <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> dest,
    <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> gradient_input
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Runs cudnnActivationBackward with the tanh activation descriptor, using dest</font>
    <font color='#009900'>// and gradient_input as inputs and grad as the output tensor.  All three</font>
    <font color='#009900'>// tensors must have the same dimensions.</font>
    <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>
        <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,gradient_input<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font>
        <font color='#BB00BB'>have_same_dimensions</font><font face='Lucida Console'>(</font>dest,grad<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>true</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// An empty tensor needs no work, so skip the cuDNN call entirely.</font>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>dest.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>return</font>;
    <b>}</b>

    <font color='#009900'>// The output scaling factor (cuDNN's beta argument) is 0 when grad and</font>
    <font color='#009900'>// gradient_input are the same object and 1 otherwise.</font>
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> in_place <font color='#5555FF'>=</font> <font color='#BB00BB'>is_same_object</font><font face='Lucida Console'>(</font>grad,gradient_input<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> in_scale <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>float</u></font> out_scale <font color='#5555FF'>=</font> in_place ? <font color='#979000'>0</font> : <font color='#979000'>1</font>;
    <font color='#BB00BB'>CHECK_CUDNN</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cudnnActivationBackward</font><font face='Lucida Console'>(</font>
        <font color='#BB00BB'>context</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#BB00BB'>tanh_activation_descriptor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#5555FF'>&amp;</font>in_scale,
        <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>gradient_input<font face='Lucida Console'>)</font>, gradient_input.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>dest<font face='Lucida Console'>)</font>, dest.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        <font color='#5555FF'>&amp;</font>out_scale,
        <font color='#BB00BB'>descriptor</font><font face='Lucida Console'>(</font>grad<font face='Lucida Console'>)</font>, grad.<font color='#BB00BB'>device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ------------------------------------------------------------------------------------
</font> <b>}</b>
<b>}</b>
<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_USE_CUDA
</font>
<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_DNN_CuDNN_CPP_
</font>
</pre></body></html>