|
<html><head><title>dlib C++ Library - utilities.h</title></head><body bgcolor='white'><pre> |
|
<font color='#009900'>// Copyright (C) 2016 Davis E. King (davis@dlib.net) |
|
</font><font color='#009900'>// License: Boost Software License See LICENSE.txt for the full license. |
|
</font><font color='#0000FF'>#ifndef</font> DLIB_DNn_UTILITIES_H_ |
|
<font color='#0000FF'>#define</font> DLIB_DNn_UTILITIES_H_ |
|
|
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='core.h.html'>core.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='utilities_abstract.h.html'>utilities_abstract.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../geometry.h.html'>../geometry.h</a>" |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>fstream<font color='#5555FF'>></font> |
|
|
|
<font color='#0000FF'>namespace</font> dlib |
|
<b>{</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>inline</font> <font color='#0000FF'><u>void</u></font> <b><a name='randomize_parameters'></a>randomize_parameters</b> <font face='Lucida Console'>(</font> |
|
tensor<font color='#5555FF'>&</font> params, |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> num_inputs_and_outputs, |
|
dlib::rand<font color='#5555FF'>&</font> rnd |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&</font> val : params<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Draw a random number to initialize the layer according to formula (16) |
|
</font> <font color='#009900'>// from Understanding the difficulty of training deep feedforward neural |
|
</font> <font color='#009900'>// networks by Xavier Glorot and Yoshua Bengio. |
|
</font> val <font color='#5555FF'>=</font> <font color='#979000'>2</font><font color='#5555FF'>*</font>rnd.<font color='#BB00BB'>get_random_float</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font><font color='#979000'>1</font>; |
|
val <font color='#5555FF'>*</font><font color='#5555FF'>=</font> std::<font color='#BB00BB'>sqrt</font><font face='Lucida Console'>(</font><font color='#979000'>6.0</font><font color='#5555FF'>/</font><font face='Lucida Console'>(</font>num_inputs_and_outputs<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>namespace</font> impl |
|
<b>{</b> |
|
<font color='#0000FF'>class</font> <b><a name='visitor_net_to_xml'></a>visitor_net_to_xml</b> |
|
<b>{</b> |
|
<font color='#0000FF'>public</font>: |
|
|
|
<b><a name='visitor_net_to_xml'></a>visitor_net_to_xml</b><font face='Lucida Console'>(</font>std::ostream<font color='#5555FF'>&</font> out_<font face='Lucida Console'>)</font> : out<font face='Lucida Console'>(</font>out_<font face='Lucida Console'>)</font> <b>{</b><b>}</b> |
|
|
|
<font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font> input_layer_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> idx, <font color='#0000FF'>const</font> input_layer_type<font color='#5555FF'>&</font> l<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'><layer idx='</font>"<font color='#5555FF'><</font><font color='#5555FF'><</font>idx<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>' type='input'>\n</font>"; |
|
<font color='#BB00BB'>to_xml</font><font face='Lucida Console'>(</font>l,out<font face='Lucida Console'>)</font>; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'></layer>\n</font>"; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> T, <font color='#0000FF'>typename</font> U<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> idx, <font color='#0000FF'>const</font> add_loss_layer<font color='#5555FF'><</font>T,U<font color='#5555FF'>></font><font color='#5555FF'>&</font> l<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'><layer idx='</font>"<font color='#5555FF'><</font><font color='#5555FF'><</font>idx<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>' type='loss'>\n</font>"; |
|
<font color='#BB00BB'>to_xml</font><font face='Lucida Console'>(</font>l.<font color='#BB00BB'>loss_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,out<font face='Lucida Console'>)</font>; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'></layer>\n</font>"; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> T, <font color='#0000FF'>typename</font> U, <font color='#0000FF'>typename</font> E<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> idx, <font color='#0000FF'>const</font> add_layer<font color='#5555FF'><</font>T,U,E<font color='#5555FF'>></font><font color='#5555FF'>&</font> l<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'><layer idx='</font>"<font color='#5555FF'><</font><font color='#5555FF'><</font>idx<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>' type='comp'>\n</font>"; |
|
<font color='#BB00BB'>to_xml</font><font face='Lucida Console'>(</font>l.<font color='#BB00BB'>layer_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,out<font face='Lucida Console'>)</font>; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'></layer>\n</font>"; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> ID, <font color='#0000FF'>typename</font> U, <font color='#0000FF'>typename</font> E<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> idx, <font color='#0000FF'>const</font> add_tag_layer<font color='#5555FF'><</font>ID,U,E<font color='#5555FF'>></font><font color='#5555FF'>&</font> <font color='#009900'>/*l*/</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'><layer idx='</font>"<font color='#5555FF'><</font><font color='#5555FF'><</font>idx<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>' type='tag' id='</font>"<font color='#5555FF'><</font><font color='#5555FF'><</font>ID<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>'/>\n</font>"; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b><a name='T'></a>T</b>, <font color='#0000FF'>typename</font> U<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> idx, <font color='#0000FF'>const</font> add_skip_layer<font color='#5555FF'><</font>T,U<font color='#5555FF'>></font><font color='#5555FF'>&</font> <font color='#009900'>/*l*/</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'><layer idx='</font>"<font color='#5555FF'><</font><font color='#5555FF'><</font>idx<font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>' type='skip' id='</font>"<font color='#5555FF'><</font><font color='#5555FF'><</font><font face='Lucida Console'>(</font>tag_id<font color='#5555FF'><</font>T<font color='#5555FF'>></font>::id<font face='Lucida Console'>)</font><font color='#5555FF'><</font><font color='#5555FF'><</font>"<font color='#CC0000'>'/>\n</font>"; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>private</font>: |
|
|
|
std::ostream<font color='#5555FF'>&</font> out; |
|
<b>}</b>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> net_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='net_to_xml'></a>net_to_xml</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> net_type<font color='#5555FF'>&</font> net, |
|
std::ostream<font color='#5555FF'>&</font> out |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>auto</font> old_precision <font color='#5555FF'>=</font> out.<font color='#BB00BB'>precision</font><font face='Lucida Console'>(</font><font color='#979000'>9</font><font face='Lucida Console'>)</font>; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'><net>\n</font>"; |
|
<font color='#BB00BB'>visit_layers</font><font face='Lucida Console'>(</font>net, impl::<font color='#BB00BB'>visitor_net_to_xml</font><font face='Lucida Console'>(</font>out<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'></net>\n</font>"; |
|
<font color='#009900'>// restore the original stream precision. |
|
</font> out.<font color='#BB00BB'>precision</font><font face='Lucida Console'>(</font>old_precision<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
// Convenience overload: writes the XML description of net (see the
// std::ostream overload of net_to_xml) to the file named filename, creating or
// truncating it.
//
// Throws std::ios_base::failure if the file cannot be opened for writing.
template <typename net_type>
void net_to_xml (
    const net_type& net,
    const std::string& filename
)
{
    std::ofstream fout(filename);
    // Without this check an unwritable path (missing directory, no
    // permission, ...) would silently produce no output at all.
    if (!fout)
        throw std::ios_base::failure("Unable to open " + filename + " for writing.");
    net_to_xml(net, fout);
}
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>namespace</font> impl |
|
<b>{</b> |
|
|
|
<font color='#0000FF'>class</font> <b><a name='visitor_net_map_input_to_output'></a>visitor_net_map_input_to_output</b> |
|
<b>{</b> |
|
<font color='#0000FF'>public</font>: |
|
|
|
<b><a name='visitor_net_map_input_to_output'></a>visitor_net_map_input_to_output</b><font face='Lucida Console'>(</font>dpoint<font color='#5555FF'>&</font> p_<font face='Lucida Console'>)</font> : p<font face='Lucida Console'>(</font>p_<font face='Lucida Console'>)</font> <b>{</b><b>}</b> |
|
|
|
dpoint<font color='#5555FF'>&</font> p; |
|
|
|
<font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font> input_layer_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> input_layer_type<font color='#5555FF'>&</font> <font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> T, <font color='#0000FF'>typename</font> U<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> add_loss_layer<font color='#5555FF'><</font>T,U<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> T, <font color='#0000FF'>typename</font> U, <font color='#0000FF'>typename</font> E<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> add_layer<font color='#5555FF'><</font>T,U,E<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
p <font color='#5555FF'>=</font> net.<font color='#BB00BB'>layer_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>map_input_to_output</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>bool</u></font> B, <font color='#0000FF'>typename</font> T, <font color='#0000FF'>typename</font> U, <font color='#0000FF'>typename</font> E<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dimpl::subnet_wrapper<font color='#5555FF'><</font>add_layer<font color='#5555FF'><</font>T,U,E<font color='#5555FF'>></font>,B<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
p <font color='#5555FF'>=</font> net.<font color='#BB00BB'>layer_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>map_input_to_output</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> ID, <font color='#0000FF'>typename</font> U, <font color='#0000FF'>typename</font> E<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> add_tag_layer<font color='#5555FF'><</font>ID,U,E<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// tag layers are an identity transform, so do nothing |
|
</font> <font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>bool</u></font> is_first, <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> ID, <font color='#0000FF'>typename</font> U, <font color='#0000FF'>typename</font> E<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dimpl::subnet_wrapper<font color='#5555FF'><</font>add_tag_layer<font color='#5555FF'><</font>ID,U,E<font color='#5555FF'>></font>,is_first<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// tag layers are an identity transform, so do nothing |
|
</font> <font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b><a name='TAG_TYPE'></a>TAG_TYPE</b>, <font color='#0000FF'>typename</font> U<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> add_skip_layer<font color='#5555FF'><</font>TAG_TYPE,U<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>layer<font color='#5555FF'><</font>TAG_TYPE<font color='#5555FF'>></font><font face='Lucida Console'>(</font>net<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>bool</u></font> is_first, <font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b><a name='TAG_TYPE'></a>TAG_TYPE</b>, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dimpl::subnet_wrapper<font color='#5555FF'><</font>add_skip_layer<font color='#5555FF'><</font>TAG_TYPE,SUBNET<font color='#5555FF'>></font>,is_first<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// skip layers are an identity transform, so do nothing |
|
</font> <font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>layer<font color='#5555FF'><</font>TAG_TYPE<font color='#5555FF'>></font><font face='Lucida Console'>(</font>net<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<b>}</b>; |
|
|
|
<font color='#0000FF'>class</font> <b><a name='visitor_net_map_output_to_input'></a>visitor_net_map_output_to_input</b> |
|
<b>{</b> |
|
<font color='#0000FF'>public</font>: |
|
<b><a name='visitor_net_map_output_to_input'></a>visitor_net_map_output_to_input</b><font face='Lucida Console'>(</font>dpoint<font color='#5555FF'>&</font> p_<font face='Lucida Console'>)</font> : p<font face='Lucida Console'>(</font>p_<font face='Lucida Console'>)</font> <b>{</b><b>}</b> |
|
|
|
dpoint<font color='#5555FF'>&</font> p; |
|
|
|
<font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font> input_layer_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> input_layer_type<font color='#5555FF'>&</font> <font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> T, <font color='#0000FF'>typename</font> U<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> add_loss_layer<font color='#5555FF'><</font>T,U<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> T, <font color='#0000FF'>typename</font> U, <font color='#0000FF'>typename</font> E<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> add_layer<font color='#5555FF'><</font>T,U,E<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
p <font color='#5555FF'>=</font> net.<font color='#BB00BB'>layer_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>map_output_to_input</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font>; |
|
<font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>bool</u></font> B, <font color='#0000FF'>typename</font> T, <font color='#0000FF'>typename</font> U, <font color='#0000FF'>typename</font> E<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dimpl::subnet_wrapper<font color='#5555FF'><</font>add_layer<font color='#5555FF'><</font>T,U,E<font color='#5555FF'>></font>,B<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
p <font color='#5555FF'>=</font> net.<font color='#BB00BB'>layer_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>map_output_to_input</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font>; |
|
<font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> ID, <font color='#0000FF'>typename</font> U, <font color='#0000FF'>typename</font> E<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> add_tag_layer<font color='#5555FF'><</font>ID,U,E<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// tag layers are an identity transform, so do nothing |
|
</font> <font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>bool</u></font> is_first, <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> ID, <font color='#0000FF'>typename</font> U, <font color='#0000FF'>typename</font> E<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dimpl::subnet_wrapper<font color='#5555FF'><</font>add_tag_layer<font color='#5555FF'><</font>ID,U,E<font color='#5555FF'>></font>,is_first<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// tag layers are an identity transform, so do nothing |
|
</font> <font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b><a name='TAG_TYPE'></a>TAG_TYPE</b>, <font color='#0000FF'>typename</font> U<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> add_skip_layer<font color='#5555FF'><</font>TAG_TYPE,U<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>layer<font color='#5555FF'><</font>TAG_TYPE<font color='#5555FF'>></font><font face='Lucida Console'>(</font>net<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>bool</u></font> is_first, <font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b><a name='TAG_TYPE'></a>TAG_TYPE</b>, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dimpl::subnet_wrapper<font color='#5555FF'><</font>add_skip_layer<font color='#5555FF'><</font>TAG_TYPE,SUBNET<font color='#5555FF'>></font>,is_first<font color='#5555FF'>></font><font color='#5555FF'>&</font> net<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// skip layers are an identity transform, so do nothing |
|
</font> <font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>layer<font color='#5555FF'><</font>TAG_TYPE<font color='#5555FF'>></font><font face='Lucida Console'>(</font>net<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<b>}</b>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> net_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'>inline</font> dpoint <b><a name='input_tensor_to_output_tensor'></a>input_tensor_to_output_tensor</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> net_type<font color='#5555FF'>&</font> net, |
|
dpoint p |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
impl::visitor_net_map_input_to_output <font color='#BB00BB'>temp</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>temp</font><font face='Lucida Console'>(</font>net<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font> p; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Translates the point p from the coordinate system of net's output tensor back |
|
</font><font color='#009900'>// into the coordinate system of net's input tensor (contract in utilities_abstract.h). |
|
</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> net_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'>inline</font> dpoint <b><a name='output_tensor_to_input_tensor'></a>output_tensor_to_input_tensor</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> net_type<font color='#5555FF'>&</font> net, |
|
dpoint p |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// The mapping visitor is built from p and then walks net; p is returned |
|
</font> <font color='#009900'>// afterwards, so the visitor presumably rewrites it through a reference. |
|
</font> impl::visitor_net_map_output_to_input <font color='#BB00BB'>to_input</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>to_input</font><font face='Lucida Console'>(</font>net<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font> p; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#009900'>// Returns the sum of size() over every parameter tensor that |
|
</font><font color='#009900'>// visit_layer_parameters() reports for net, i.e. the parameter count of net. |
|
</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> net_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>size_t</u></font> <b><a name='count_parameters'></a>count_parameters</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> net_type<font color='#5555FF'>&</font> net |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'><u>size_t</u></font> total <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#009900'>// Accumulator invoked once per parameter tensor. |
|
</font> <font color='#0000FF'>auto</font> accumulate_size <font color='#5555FF'>=</font> [<font color='#5555FF'>&</font>total]<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> params<font face='Lucida Console'>)</font> <b>{</b> total <font color='#5555FF'>+</font><font color='#5555FF'>=</font> params.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <b>}</b>; |
|
<font color='#BB00BB'>visit_layer_parameters</font><font face='Lucida Console'>(</font>net, accumulate_size<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font> total; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>namespace</font> impl |
|
<b>{</b> |
|
<font color='#009900'>// Visitor that calls set_learning_rate_multiplier(l, m) on each layer l it is |
|
</font> <font color='#009900'>// applied to, where m is the multiplier given at construction. |
|
</font> <font color='#0000FF'>class</font> <b><a name='visitor_learning_rate_multiplier'></a>visitor_learning_rate_multiplier</b> |
|
<b>{</b> |
|
<font color='#0000FF'>public</font>: |
|
<font color='#009900'>// explicit: a bare double must not silently convert into a visitor object. |
|
</font> <font color='#0000FF'>explicit</font> <b><a name='visitor_learning_rate_multiplier'></a>visitor_learning_rate_multiplier</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font> new_learning_rate_multiplier_<font face='Lucida Console'>)</font> : |
|
new_learning_rate_multiplier<font face='Lucida Console'>(</font>new_learning_rate_multiplier_<font face='Lucida Console'>)</font> <b>{</b><b>}</b> |
|
|
|
<font color='#009900'>// Applies the stored multiplier to layer l. |
|
</font> <font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> layer<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>layer<font color='#5555FF'>&</font> l<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>set_learning_rate_multiplier</font><font face='Lucida Console'>(</font>l, new_learning_rate_multiplier<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>private</font>: |
|
|
|
<font color='#0000FF'><u>double</u></font> new_learning_rate_multiplier; |
|
<b>}</b>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Sets the learning rate multiplier of every computational layer in net to |
|
</font><font color='#009900'>// learning_rate_multiplier, which must be non-negative (checked by DLIB_CASSERT). |
|
</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> net_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='set_all_learning_rate_multipliers'></a>set_all_learning_rate_multipliers</b><font face='Lucida Console'>(</font> |
|
net_type<font color='#5555FF'>&</font> net, |
|
<font color='#0000FF'><u>double</u></font> learning_rate_multiplier |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>learning_rate_multiplier <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// One visitor instance is reused for every computational layer of net. |
|
</font> impl::visitor_learning_rate_multiplier <font color='#BB00BB'>apply_multiplier</font><font face='Lucida Console'>(</font>learning_rate_multiplier<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>visit_computational_layers</font><font face='Lucida Console'>(</font>net, apply_multiplier<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Sets the learning rate multiplier of the computational layers selected by |
|
</font><font color='#009900'>// visit_computational_layers_range<begin,end> (presumably the half-open range |
|
</font><font color='#009900'>// [begin, end), since end may equal net_type::num_layers -- confirm). The range |
|
</font><font color='#009900'>// is validated at compile time; the multiplier must be non-negative. |
|
</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>size_t</u></font> begin, <font color='#0000FF'><u>size_t</u></font> end, <font color='#0000FF'>typename</font> net_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='set_learning_rate_multipliers_range'></a>set_learning_rate_multipliers_range</b><font face='Lucida Console'>(</font> |
|
net_type<font color='#5555FF'>&</font> net, |
|
<font color='#0000FF'><u>double</u></font> learning_rate_multiplier |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Compile-time range checks. |
|
</font> <font color='#BB00BB'>static_assert</font><font face='Lucida Console'>(</font>begin <font color='#5555FF'><</font><font color='#5555FF'>=</font> end, "<font color='#CC0000'>Invalid range</font>"<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>static_assert</font><font face='Lucida Console'>(</font>end <font color='#5555FF'><</font><font color='#5555FF'>=</font> net_type::num_layers, "<font color='#CC0000'>Invalid range</font>"<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>learning_rate_multiplier <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
impl::visitor_learning_rate_multiplier <font color='#BB00BB'>range_setter</font><font face='Lucida Console'>(</font>learning_rate_multiplier<font face='Lucida Console'>)</font>; |
|
visit_computational_layers_range<font color='#5555FF'><</font>begin, end<font color='#5555FF'>></font><font face='Lucida Console'>(</font>net, range_setter<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font><b>}</b> |
|
|
|
<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_DNn_UTILITIES_H_ |
|
</font> |
|
|
|
|
|
|
|
</pre></body></html> |