AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
<html><!-- Created using the cpp_pretty_printer from the dlib C++ library. See http://dlib.net for updates. --><head><title>dlib C++ Library - trainer.h</title></head><body bgcolor='white'><pre>
<font color='#009900'>// Copyright (C) 2015 Davis E. King (davis@dlib.net)
</font><font color='#009900'>// License: Boost Software License See LICENSE.txt for the full license.
</font><font color='#0000FF'>#ifndef</font> DLIB_DNn_TRAINER_H_
<font color='#0000FF'>#define</font> DLIB_DNn_TRAINER_H_
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='trainer_abstract.h.html'>trainer_abstract.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='core.h.html'>core.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='solvers.h.html'>solvers.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../statistics.h.html'>../statistics.h</a>"
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>chrono<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>fstream<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>sstream<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../serialize.h.html'>../serialize.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../pipe.h.html'>../pipe.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../threads.h.html'>../threads.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../cuda/cuda_dlib.h.html'>../cuda/cuda_dlib.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../statistics/running_gradient.h.html'>../statistics/running_gradient.h</a>"
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>atomic<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>cstdio<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>set<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>future<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>exception<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>mutex<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../dir_nav.h.html'>../dir_nav.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../md5.h.html'>../md5.h</a>"
<font color='#0000FF'>namespace</font> dlib
<b>{</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'>namespace</font> impl
<b>{</b>
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> training_label_type<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>struct</font> <b><a name='dnn_job_t'></a>dnn_job_t</b>
<b>{</b>
<b><a name='dnn_job_t'></a>dnn_job_t</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>default</font>;
<b><a name='dnn_job_t'></a>dnn_job_t</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dnn_job_t<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>;
dnn_job_t<font color='#5555FF'>&amp;</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>=</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dnn_job_t<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>;
std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>training_label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> labels;
std::vector<font color='#5555FF'>&lt;</font>resizable_tensor<font color='#5555FF'>&gt;</font> t;
std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font><font color='#5555FF'>&gt;</font> have_data; <font color='#009900'>// have_data[i] is true if there is data in labels[i] and t[i].
</font> <font color='#0000FF'><u>bool</u></font> test_only <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
<b>}</b>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> training_label_type<font color='#5555FF'>&gt;</font>
<font color='#0000FF'><u>void</u></font> <b><a name='swap'></a>swap</b><font face='Lucida Console'>(</font>dnn_job_t<font color='#5555FF'>&lt;</font>training_label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> a, dnn_job_t<font color='#5555FF'>&lt;</font>training_label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> b<font face='Lucida Console'>)</font>
<b>{</b>
a.labels.<font color='#BB00BB'>swap</font><font face='Lucida Console'>(</font>b.labels<font face='Lucida Console'>)</font>;
a.t.<font color='#BB00BB'>swap</font><font face='Lucida Console'>(</font>b.t<font face='Lucida Console'>)</font>;
a.have_data.<font color='#BB00BB'>swap</font><font face='Lucida Console'>(</font>b.have_data<font face='Lucida Console'>)</font>;
std::<font color='#BB00BB'>swap</font><font face='Lucida Console'>(</font>a.test_only,b.test_only<font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<font color='#0000FF'>enum</font> <font color='#0000FF'>class</font> <b><a name='force_flush_to_disk'></a>force_flush_to_disk</b> <b>{</b>
no <font color='#5555FF'>=</font> <font color='#979000'>0</font>,
yes <font color='#5555FF'>=</font> <font color='#979000'>1</font>
<b>}</b>;
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
<font color='#0000FF'>typename</font> net_type,
<font color='#0000FF'>typename</font> solver_type <font color='#5555FF'>=</font> sgd
<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>class</font> <b><a name='dnn_trainer'></a>dnn_trainer</b> : <font color='#0000FF'>private</font> threaded_object
<b>{</b>
<font color='#0000FF'>public</font>:
<b><a name='static_assert'></a>static_assert</b><font face='Lucida Console'>(</font>is_loss_layer_type<font color='#5555FF'>&lt;</font>net_type<font color='#5555FF'>&gt;</font>::value,
"<font color='#CC0000'>The last layer in a network must be a loss layer.</font>"<font face='Lucida Console'>)</font>;
<font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> net_type::training_label_type training_label_type;
<font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> net_type::input_type input_type;
<font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>size_t</u></font> num_computational_layers <font color='#5555FF'>=</font> net_type::num_computational_layers;
<font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>size_t</u></font> num_layers <font color='#5555FF'>=</font> net_type::num_layers;
<font color='#0000FF'>using</font> threads <font color='#5555FF'>=</font> std::vector<font color='#5555FF'>&lt;</font>std::shared_ptr<font color='#5555FF'>&lt;</font>thread_pool<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#0000FF'>private</font>:
<font color='#0000FF'>typedef</font> impl::dnn_job_t<font color='#5555FF'>&lt;</font>training_label_type<font color='#5555FF'>&gt;</font> job_t;
<font color='#0000FF'>public</font>:
<b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>;
<b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dnn_trainer<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>;
dnn_trainer<font color='#5555FF'>&amp;</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>=</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dnn_trainer<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>;
<font color='#0000FF'>explicit</font> <b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font>net_type<font color='#5555FF'>&amp;</font> net_<font face='Lucida Console'>)</font> : job_pipe<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>, net<font face='Lucida Console'>(</font>net_<font face='Lucida Console'>)</font>
<b>{</b>
solver_type default_solver;
devices.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'>&lt;</font>device_data<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>dlib::cuda::<font color='#BB00BB'>get_device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, net, default_solver<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>init</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font>
net_type<font color='#5555FF'>&amp;</font> net_,
<font color='#0000FF'>const</font> solver_type<font color='#5555FF'>&amp;</font> solver_
<font face='Lucida Console'>)</font> : job_pipe<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>, net<font face='Lucida Console'>(</font>net_<font face='Lucida Console'>)</font>
<b>{</b>
devices.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'>&lt;</font>device_data<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>dlib::cuda::<font color='#BB00BB'>get_device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, net, solver_<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>init</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font>
net_type<font color='#5555FF'>&amp;</font> net_,
<font color='#0000FF'>const</font> solver_type<font color='#5555FF'>&amp;</font> solver_,
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> cuda_extra_devices,
std::shared_ptr<font color='#5555FF'>&lt;</font>threads<font color='#5555FF'>&gt;</font> thread_pools_ <font color='#5555FF'>=</font> std::shared_ptr<font color='#5555FF'>&lt;</font>threads<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<font face='Lucida Console'>)</font> : job_pipe<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>, thread_pools<font face='Lucida Console'>(</font>thread_pools_<font face='Lucida Console'>)</font>, net<font face='Lucida Console'>(</font>net_<font face='Lucida Console'>)</font>
<b>{</b>
devices.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'>&lt;</font>device_data<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>dlib::cuda::<font color='#BB00BB'>get_device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, net, solver_<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> total_devices <font color='#5555FF'>=</font> dlib::cuda::<font color='#BB00BB'>get_num_devices</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// Make device contexts for the extra device ids but be careful to avoid any
</font> <font color='#009900'>// duplicate ids.
</font> std::set<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>int</u></font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>temp</font><font face='Lucida Console'>(</font>cuda_extra_devices.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, cuda_extra_devices.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
temp.<font color='#BB00BB'>erase</font><font face='Lucida Console'>(</font>devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>device_id<font face='Lucida Console'>)</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font> id : temp<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#979000'>0</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> id <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> id <font color='#5555FF'>&lt;</font> total_devices, "<font color='#CC0000'>Invalid CUDA device id given to dnn_trainer.</font>"<font face='Lucida Console'>)</font>;
<font color='#009900'>// Switch to this device so that any tensor objects that get allocated when
</font> <font color='#009900'>// we create the device context happen on this device.
</font> dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>id<font face='Lucida Console'>)</font>;
devices.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'>&lt;</font>device_data<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>id, net, solver_, <font color='#BB00BB'>clone_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Set the current device back to what it was before this constructor was
</font> <font color='#009900'>// called.
</font> dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>device_id<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>init</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
~<b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>
<b>{</b>
job_pipe.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>stop</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>wait</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
net_type<font color='#5555FF'>&amp;</font> <b><a name='get_net'></a>get_net</b> <font face='Lucida Console'>(</font>
force_flush_to_disk force_flush <font color='#5555FF'>=</font> force_flush_to_disk::yes
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font>force_flush <font color='#5555FF'>=</font><font color='#5555FF'>=</font> force_flush_to_disk::yes<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>propagate_exception</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>return</font> net;
<b>}</b>
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_mini_batch_size'></a>get_mini_batch_size</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> <b>{</b> <font color='#0000FF'>return</font> mini_batch_size; <b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='set_mini_batch_size'></a>set_mini_batch_size</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> batch_size
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>batch_size <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
mini_batch_size <font color='#5555FF'>=</font> batch_size;
<b>}</b>
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_max_num_epochs'></a>get_max_num_epochs</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> <b>{</b> <font color='#0000FF'>return</font> max_num_epochs; <b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='set_max_num_epochs'></a>set_max_num_epochs</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> num
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
max_num_epochs <font color='#5555FF'>=</font> num;
<b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='be_verbose'></a>be_verbose</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>
<b>{</b>
verbose <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
<b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='be_quiet'></a>be_quiet</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>
<b>{</b>
verbose <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
<b>}</b>
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>solver_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> <b><a name='get_solvers'></a>get_solvers</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>propagate_exception</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>return</font> devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>solvers;
<b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='train_one_step'></a>train_one_step</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>input_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> data,
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>training_label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> labels
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, data.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
<font color='#0000FF'>typename</font> data_iterator,
<font color='#0000FF'>typename</font> label_iterator
<font color='#5555FF'>&gt;</font>
<font color='#0000FF'><u>void</u></font> <b><a name='train_one_step'></a>train_one_step</b> <font face='Lucida Console'>(</font>
data_iterator dbegin,
data_iterator dend,
label_iterator lbegin
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>distance</font><font face='Lucida Console'>(</font>dbegin, dend<font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>print_periodic_verbose_status</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>false</font>, dbegin, dend, lbegin<font face='Lucida Console'>)</font>;
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>train_one_step_calls;
<b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='train_one_step'></a>train_one_step</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>input_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> data
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, data.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
<font color='#0000FF'>typename</font> data_iterator
<font color='#5555FF'>&gt;</font>
<font color='#0000FF'><u>void</u></font> <b><a name='train_one_step'></a>train_one_step</b> <font face='Lucida Console'>(</font>
data_iterator dbegin,
data_iterator dend
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>distance</font><font face='Lucida Console'>(</font>dbegin, dend<font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>print_periodic_verbose_status</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>false</font>, dbegin, dend<font face='Lucida Console'>)</font>;
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>train_one_step_calls;
<b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='test_one_step'></a>test_one_step</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>input_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> data,
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>training_label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> labels
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>test_one_step</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, data.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
<font color='#0000FF'>typename</font> data_iterator,
<font color='#0000FF'>typename</font> label_iterator
<font color='#5555FF'>&gt;</font>
<font color='#0000FF'><u>void</u></font> <b><a name='test_one_step'></a>test_one_step</b> <font face='Lucida Console'>(</font>
data_iterator dbegin,
data_iterator dend,
label_iterator lbegin
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>distance</font><font face='Lucida Console'>(</font>dbegin, dend<font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>print_periodic_verbose_status</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>true</font>, dbegin, dend, lbegin<font face='Lucida Console'>)</font>;
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>test_one_step_calls;
<b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='test_one_step'></a>test_one_step</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>input_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> data
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>test_one_step</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, data.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
<font color='#0000FF'>typename</font> data_iterator
<font color='#5555FF'>&gt;</font>
<font color='#0000FF'><u>void</u></font> <b><a name='test_one_step'></a>test_one_step</b> <font face='Lucida Console'>(</font>
data_iterator dbegin,
data_iterator dend
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>distance</font><font face='Lucida Console'>(</font>dbegin, dend<font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>print_periodic_verbose_status</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>true</font>, dbegin, dend<font face='Lucida Console'>)</font>;
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>test_one_step_calls;
<b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='train'></a>train</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>input_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> data,
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>training_label_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> labels
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// The reason these two loops don't initialize their counter variables but
</font> <font color='#009900'>// instead use class members is so we can include the state of the loops in the
</font> <font color='#009900'>// stuff written by sync_to_disk()
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font>;
epoch_iteration <font color='#5555FF'>&lt;</font> max_num_epochs <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> learning_rate <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> min_learning_rate;
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>epoch_iteration<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std::chrono;
last_time <font color='#5555FF'>=</font> system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>clear_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font>; epoch_pos <font color='#5555FF'>&lt;</font> data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> learning_rate <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> min_learning_rate; epoch_pos <font color='#5555FF'>+</font><font color='#5555FF'>=</font> mini_batch_size<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>auto</font> now_time <font color='#5555FF'>=</font> system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>now_time<font color='#5555FF'>-</font>last_time <font color='#5555FF'>&gt;</font> <font color='#BB00BB'>seconds</font><font face='Lucida Console'>(</font><font color='#979000'>20</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<b>{</b>
last_time <font color='#5555FF'>=</font> now_time;
<font color='#0000FF'>auto</font> iter <font color='#5555FF'>=</font> epoch_iteration <font color='#5555FF'>+</font> epoch_pos<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>epoch: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>iter<font face='Lucida Console'>)</font>,epoch_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>learning rate: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>learning_rate<font face='Lucida Console'>)</font>,lr_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>average loss: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>";
<font color='#BB00BB'>print_progress</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>false</font>, data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>epoch_pos,
data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>epoch_pos<font color='#5555FF'>+</font>mini_batch_size,data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,
labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>epoch_pos<font face='Lucida Console'>)</font>;
<b>}</b>
epoch_pos <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// Capitalize the E in Epoch so it's easy to grep out the lines that
</font> <font color='#009900'>// are for full epoch status statements.
</font> std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Epoch: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>epoch_iteration<font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>,epoch_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>learning rate: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>learning_rate<font face='Lucida Console'>)</font>,lr_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>average loss: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>";
<font color='#BB00BB'>print_progress</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// if we modified the network at all then be sure to sync the final result.
</font> <font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font color='#979000'>true</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- train(data): overload without labels.
     * Asserts that data is non-empty and statically asserts that
       training_label_type is no_label_type.
     * Outer loop: runs while epoch_iteration is below max_num_epochs and
       learning_rate is at least min_learning_rate.
     * Inner loop: advances the member epoch_pos through data in steps of
       mini_batch_size, calling sync_to_disk() and then send_job() on the range
       [epoch_pos, min(epoch_pos + mini_batch_size, data.size())).
     * When verbose is set, a lowercase "epoch:" line is printed if more than
       20 seconds have passed since last_time, and a capitalized "Epoch:" line
       is printed after every full pass over data.
     * After the loops it calls wait_for_thread_to_pause() and sync_to_disk(true).
-->
<font color='#0000FF'><u>void</u></font> <b><a name='train'></a>train</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>input_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> data
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> has_unsupervised_loss <font color='#5555FF'>=</font> std::is_same<font color='#5555FF'>&lt;</font>no_label_type, training_label_type<font color='#5555FF'>&gt;</font>::value;
<font color='#BB00BB'>static_assert</font><font face='Lucida Console'>(</font>has_unsupervised_loss,
"<font color='#CC0000'>You can only call this version of train() when using an unsupervised loss.</font>"<font face='Lucida Console'>)</font>;
<font color='#009900'>// The reason these two loops don't initialize their counter variables but
</font> <font color='#009900'>// instead use class members is so we can include the state of the loops in the
</font> <font color='#009900'>// stuff written by sync_to_disk()
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font>;
epoch_iteration <font color='#5555FF'>&lt;</font> max_num_epochs <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> learning_rate <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> min_learning_rate;
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>epoch_iteration<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std::chrono;
last_time <font color='#5555FF'>=</font> system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>clear_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font>; epoch_pos <font color='#5555FF'>&lt;</font> data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> learning_rate <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> min_learning_rate; epoch_pos <font color='#5555FF'>+</font><font color='#5555FF'>=</font> mini_batch_size<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>auto</font> now_time <font color='#5555FF'>=</font> system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>now_time<font color='#5555FF'>-</font>last_time <font color='#5555FF'>&gt;</font> <font color='#BB00BB'>seconds</font><font face='Lucida Console'>(</font><font color='#979000'>20</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<b>{</b>
last_time <font color='#5555FF'>=</font> now_time;
<font color='#0000FF'>auto</font> iter <font color='#5555FF'>=</font> epoch_iteration <font color='#5555FF'>+</font> epoch_pos<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>epoch: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>iter<font face='Lucida Console'>)</font>,epoch_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>learning rate: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>learning_rate<font face='Lucida Console'>)</font>,lr_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>average loss: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>";
<font color='#BB00BB'>print_progress</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>false</font>, data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>epoch_pos,
data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>epoch_pos<font color='#5555FF'>+</font>mini_batch_size,data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
epoch_pos <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// Capitalize the E in Epoch so it's easy to grep out the lines that
</font> <font color='#009900'>// are for full epoch status statements.
</font> std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Epoch: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>epoch_iteration<font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>,epoch_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>learning rate: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>learning_rate<font face='Lucida Console'>)</font>,lr_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>average loss: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>";
<font color='#BB00BB'>print_progress</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// if we modified the network at all then be sure to sync the final result.
</font> <font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font color='#979000'>true</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- set_synchronization_file(filename, time_between_syncs_):
     Stores the current system_clock time in last_sync_time, then records the
     file name and the sync interval (the interval defaults to 15 minutes).
     It then opens the path returned by newest_syncfile() in binary mode and,
     if the stream opened successfully, deserializes it into *this.
-->
<font color='#0000FF'><u>void</u></font> <b><a name='set_synchronization_file'></a>set_synchronization_file</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> std::string<font color='#5555FF'>&amp;</font> filename,
std::chrono::seconds time_between_syncs_ <font color='#5555FF'>=</font> std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>15</font><font face='Lucida Console'>)</font>
<font face='Lucida Console'>)</font>
<b>{</b>
last_sync_time <font color='#5555FF'>=</font> std::chrono::system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
sync_filename <font color='#5555FF'>=</font> filename;
time_between_syncs <font color='#5555FF'>=</font> time_between_syncs_;
<font color='#009900'>// check if the sync file already exists, if it does we should load it.
</font> std::ifstream <font color='#BB00BB'>fin</font><font face='Lucida Console'>(</font><font color='#BB00BB'>newest_syncfile</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, std::ios::binary<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>fin<font face='Lucida Console'>)</font>
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font>, fin<font face='Lucida Console'>)</font>;
<b>}</b>
<!-- get_synchronization_file(): returns a const reference to the stored
     sync_filename member. Note that this accessor is not declared const. -->
<font color='#0000FF'>const</font> std::string<font color='#5555FF'>&amp;</font> <b><a name='get_synchronization_file'></a>get_synchronization_file</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>return</font> sync_filename;
<b>}</b>
<!-- get_average_loss(): calls wait_for_thread_to_pause() and then returns
     rs.mean(), the mean of the values accumulated in rs. -->
<font color='#0000FF'><u>double</u></font> <b><a name='get_average_loss'></a>get_average_loss</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>return</font> rs.<font color='#BB00BB'>mean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- get_average_test_loss(): calls wait_for_thread_to_pause() and then returns
     rs_test.mean(), the mean of the values accumulated in rs_test. -->
<font color='#0000FF'><u>double</u></font> <b><a name='get_average_test_loss'></a>get_average_test_loss</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>return</font> rs_test.<font color='#BB00BB'>mean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- clear_average_loss(): calls wait_for_thread_to_pause() and then clears rs,
     the accumulator read by get_average_loss(). rs_test is not touched. -->
<font color='#0000FF'><u>void</u></font> <b><a name='clear_average_loss'></a>clear_average_loss</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
rs.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- set_learning_rate(lr):
     Asserts lr > 0 and calls wait_for_thread_to_pause().
     If lr differs from the current learning_rate, both "steps without
     progress" counters are reset to 0 and the three loss history containers
     are cleared. The new rate is then stored, and lr_schedule is resized to 0.
-->
<font color='#0000FF'><u>void</u></font> <b><a name='set_learning_rate'></a>set_learning_rate</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'><u>double</u></font> lr
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>lr <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>learning_rate <font color='#5555FF'>!</font><font color='#5555FF'>=</font> lr<font face='Lucida Console'>)</font>
<b>{</b>
steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
previous_loss_values.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
test_previous_loss_values.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
previous_loss_values_to_keep_until_disk_sync.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
learning_rate <font color='#5555FF'>=</font> lr;
lr_schedule.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- get_learning_rate(): returns the learning_rate member. -->
<font color='#0000FF'><u>double</u></font> <b><a name='get_learning_rate'></a>get_learning_rate</b><font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#0000FF'>return</font> learning_rate;
<b>}</b>
<!-- set_min_learning_rate(lr): asserts lr > 0, calls
     wait_for_thread_to_pause(), resizes lr_schedule to 0 and stores lr in
     min_learning_rate. -->
<font color='#0000FF'><u>void</u></font> <b><a name='set_min_learning_rate'></a>set_min_learning_rate</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'><u>double</u></font> lr
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>lr <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
lr_schedule.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>;
min_learning_rate <font color='#5555FF'>=</font> lr;
<b>}</b>
<!-- get_min_learning_rate(): returns the min_learning_rate member. -->
<font color='#0000FF'><u>double</u></font> <b><a name='get_min_learning_rate'></a>get_min_learning_rate</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#0000FF'>return</font> min_learning_rate;
<b>}</b>
<!-- set_learning_rate_schedule(schedule):
     Asserts that schedule is non-empty and that its smallest element is > 0.
     Sets the learning rate to schedule(0,0), the minimum learning rate to
     min(schedule) and the shrink factor to 1.
     Ordering matters: each of those three setters resizes lr_schedule to 0,
     so lr_schedule is assigned only after they have run. The schedule is
     stored as a column vector of doubles and lr_schedule_pos is reset to 0.
-->
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> EXP<font color='#5555FF'>&gt;</font>
<font color='#0000FF'><u>void</u></font> <b><a name='set_learning_rate_schedule'></a>set_learning_rate_schedule</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> matrix_exp<font color='#5555FF'>&lt;</font>EXP<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> schedule
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>schedule<font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font><font color='#BB00BB'>schedule</font><font face='Lucida Console'>(</font><font color='#979000'>0</font>,<font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>set_min_learning_rate</font><font face='Lucida Console'>(</font><font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>schedule<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>set_learning_rate_shrink_factor</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
lr_schedule <font color='#5555FF'>=</font> matrix_cast<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font><font color='#BB00BB'>reshape_to_column_vector</font><font face='Lucida Console'>(</font>schedule<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
lr_schedule_pos <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<b>}</b>
<!-- get_learning_rate_schedule(): returns a const reference to the
     lr_schedule member. -->
<font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> <b><a name='get_learning_rate_schedule'></a>get_learning_rate_schedule</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#0000FF'>return</font> lr_schedule;
<b>}</b>
<!-- set_iterations_without_progress_threshold(thresh): calls
     wait_for_thread_to_pause(), resizes lr_schedule to 0 and stores thresh in
     iter_without_progress_thresh. -->
<font color='#0000FF'><u>void</u></font> <b><a name='set_iterations_without_progress_threshold'></a>set_iterations_without_progress_threshold</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> thresh
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
lr_schedule.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>;
iter_without_progress_thresh <font color='#5555FF'>=</font> thresh;
<b>}</b>
<!-- get_iterations_without_progress_threshold(): returns the
     iter_without_progress_thresh member. -->
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_iterations_without_progress_threshold'></a>get_iterations_without_progress_threshold</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#0000FF'>return</font> iter_without_progress_thresh;
<b>}</b>
<!-- get_steps_without_progress(): returns the steps_without_progress member. -->
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_steps_without_progress'></a>get_steps_without_progress</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#0000FF'>return</font> steps_without_progress;
<b>}</b>
<!-- set_test_iterations_without_progress_threshold(thresh): calls
     wait_for_thread_to_pause(), resizes lr_schedule to 0 and stores thresh in
     test_iter_without_progress_thresh. -->
<font color='#0000FF'><u>void</u></font> <b><a name='set_test_iterations_without_progress_threshold'></a>set_test_iterations_without_progress_threshold</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> thresh
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
lr_schedule.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>;
test_iter_without_progress_thresh <font color='#5555FF'>=</font> thresh;
<b>}</b>
<!-- get_test_iterations_without_progress_threshold(): returns the
     test_iter_without_progress_thresh member. -->
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_test_iterations_without_progress_threshold'></a>get_test_iterations_without_progress_threshold</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#0000FF'>return</font> test_iter_without_progress_thresh;
<b>}</b>
<!-- get_test_steps_without_progress(): returns the
     test_steps_without_progress member. -->
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_test_steps_without_progress'></a>get_test_steps_without_progress</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#0000FF'>return</font> test_steps_without_progress;
<b>}</b>
<!-- set_learning_rate_shrink_factor(shrink):
     Asserts 0 < shrink <= 1 and calls wait_for_thread_to_pause().
     Resizes lr_schedule to 0, stores shrink in learning_rate_shrink and resets
     steps_without_progress and test_steps_without_progress to 0.
-->
<font color='#0000FF'><u>void</u></font> <b><a name='set_learning_rate_shrink_factor'></a>set_learning_rate_shrink_factor</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'><u>double</u></font> shrink
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#979000'>0</font> <font color='#5555FF'>&lt;</font> shrink <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> shrink <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
lr_schedule.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>;
learning_rate_shrink <font color='#5555FF'>=</font> shrink;
steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<b>}</b>
<!-- get_learning_rate_shrink_factor(): returns the learning_rate_shrink member. -->
<font color='#0000FF'><u>double</u></font> <b><a name='get_learning_rate_shrink_factor'></a>get_learning_rate_shrink_factor</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#0000FF'>return</font> learning_rate_shrink;
<b>}</b>
<!-- get_train_one_step_calls(): returns the train_one_step_calls member. -->
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_train_one_step_calls'></a>get_train_one_step_calls</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#0000FF'>return</font> train_one_step_calls;
<b>}</b>
<!-- get_test_one_step_calls(): returns the test_one_step_calls member. -->
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_test_one_step_calls'></a>get_test_one_step_calls</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#0000FF'>return</font> test_one_step_calls;
<b>}</b>
<font color='#0000FF'>private</font>:
<!-- record_test_loss(loss):
     Appends loss to test_previous_loss_values unconditionally, but adds it to
     rs_test only when is_finite(loss) is true. It then pops entries from the
     front of test_previous_loss_values until its size no longer exceeds
     test_iter_without_progress_thresh.
-->
<font color='#0000FF'><u>void</u></font> <b><a name='record_test_loss'></a>record_test_loss</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font> loss<font face='Lucida Console'>)</font>
<b>{</b>
test_previous_loss_values.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>is_finite</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
rs_test.<font color='#BB00BB'>add</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font>;
<font color='#009900'>// discard really old loss values.
</font> <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>test_previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> test_iter_without_progress_thresh<font face='Lucida Console'>)</font>
test_previous_loss_values.<font color='#BB00BB'>pop_front</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- record_loss(loss):
     Adds 200 to gradient_check_budget, adds loss to rs and appends it to
     previous_loss_values. It then pops entries from the front of
     previous_loss_values until its size no longer exceeds
     iter_without_progress_thresh. When sync_filename is non-empty, loss is
     also appended to previous_loss_values_to_keep_until_disk_sync, which is
     not trimmed here.
-->
<font color='#0000FF'><u>void</u></font> <b><a name='record_loss'></a>record_loss</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font> loss<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// This kind of budgeting causes our gradient checking to use a fixed amount of
</font> <font color='#009900'>// computational resources, regardless of the size of iter_without_progress_thresh.
</font> gradient_check_budget <font color='#5555FF'>+</font><font color='#5555FF'>=</font> <font color='#979000'>200</font>;
rs.<font color='#BB00BB'>add</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font>;
previous_loss_values.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font>;
<font color='#009900'>// discard really old loss values.
</font> <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> iter_without_progress_thresh<font face='Lucida Console'>)</font>
previous_loss_values.<font color='#BB00BB'>pop_front</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// separately keep another loss history until disk sync
</font> <font color='#009900'>// (but only if disk sync is enabled)
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>sync_filename.<font color='#BB00BB'>empty</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
previous_loss_values_to_keep_until_disk_sync.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font>;
<b>}</b>
<!-- compute_parameter_gradients(device, next_job, const T&):
     Generic overload. The third parameter is unnamed and never read in the
     body, so it serves only to select this overload.
     If next_job.have_data[device] is false the function returns 0.
     Otherwise it makes devices[device] current via dlib::cuda::set_device()
     and, using next_job.t[device] and next_job.labels[device].begin():
       * returns dev.net.compute_loss(...) when next_job.test_only is set;
       * returns dev.net.compute_parameter_gradients(...) otherwise.
-->
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> T<font color='#5555FF'>&gt;</font>
<font color='#0000FF'><u>double</u></font> <b><a name='compute_parameter_gradients'></a>compute_parameter_gradients</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> device, job_t<font color='#5555FF'>&amp;</font> next_job, <font color='#0000FF'>const</font> T<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.have_data[device]<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> dev <font color='#5555FF'>=</font> <font color='#5555FF'>*</font>devices[device];
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>dev.device_id<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.test_only<font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font> dev.net.<font color='#BB00BB'>compute_loss</font><font face='Lucida Console'>(</font>next_job.t[device], next_job.labels[device].<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>else</font>
<font color='#0000FF'>return</font> dev.net.<font color='#BB00BB'>compute_parameter_gradients</font><font face='Lucida Console'>(</font>next_job.t[device], next_job.labels[device].<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>else</font>
<b>{</b>
<font color='#0000FF'>return</font> <font color='#979000'>0</font>;
<b>}</b>
<b>}</b>
<!-- compute_parameter_gradients, no_label_type overload: same flow as the generic overload,
     but no label iterator is passed to the network.
     Returns 0 without touching the device when next_job.have_data[device] is false.
     Otherwise calls dlib::cuda::set_device(dev.device_id) and returns
       dev.net.compute_loss(next_job.t[device])                 when next_job.test_only is set, or
       dev.net.compute_parameter_gradients(next_job.t[device])  otherwise.
     The third parameter is unnamed: only its type matters, for overload selection.
     The listing used to declare a local no_label_type variable here that was never read;
     it has been removed. -->
<font color='#0000FF'><u>double</u></font> <b><a name='compute_parameter_gradients'></a>compute_parameter_gradients</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> device, job_t<font color='#5555FF'>&amp;</font> next_job, <font color='#0000FF'>const</font> no_label_type<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.have_data[device]<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> dev <font color='#5555FF'>=</font> <font color='#5555FF'>*</font>devices[device];
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>dev.device_id<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.test_only<font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font> dev.net.<font color='#BB00BB'>compute_loss</font><font face='Lucida Console'>(</font>next_job.t[device]<font face='Lucida Console'>)</font>;
<font color='#0000FF'>else</font>
<font color='#0000FF'>return</font> dev.net.<font color='#BB00BB'>compute_parameter_gradients</font><font face='Lucida Console'>(</font>next_job.t[device]<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>else</font>
<b>{</b>
<!-- No data was assigned to this device for this job, so it contributes a loss of 0. -->
<font color='#0000FF'>return</font> <font color='#979000'>0</font>;
<b>}</b>
<b>}</b>
<!-- update_parameters: calls dlib::cuda::set_device(dev.device_id) for devices[device], then calls
     dev.net.update_parameters() with that device's solvers (wrapped by make_sstack) and the
     trainer's current learning_rate. -->
<font color='#0000FF'><u>void</u></font> <b><a name='update_parameters'></a>update_parameters</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> device<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> dev <font color='#5555FF'>=</font> <font color='#5555FF'>*</font>devices[device];
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>dev.device_id<font face='Lucida Console'>)</font>;
dev.net.<font color='#BB00BB'>update_parameters</font><font face='Lucida Console'>(</font><font color='#BB00BB'>make_sstack</font><font face='Lucida Console'>(</font>dev.solvers<font face='Lucida Console'>)</font>, learning_rate<font face='Lucida Console'>)</font>;
<b>}</b>
<!-- thread(): the job processing loop.
     The whole body is a function try block. Any exception reaches the catch(...) at the end,
     which disables job_pipe and stores std::current_exception() in eptr under eptr_mutex.
     Each loop iteration dequeues one job_t from job_pipe; the loop ends when dequeue() returns false.
       test_only jobs: compute the loss on every device, record the mean with record_test_loss,
         possibly shrink learning_rate, then continue without updating any parameters.
       training jobs: compute gradients on every device, record the mean loss with record_loss,
         average the gradients across devices when there is more than one device, run
         update_parameters on the devices that had data, periodically copy the parameters of
         device 0 to the other devices, and finally adjust learning_rate either through the
         no-progress test or through lr_schedule. -->
<font color='#0000FF'><u>void</u></font> <b><a name='thread'></a>thread</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>try</font>
<b>{</b>
<!-- pick_which_run_update is never read for its value; it is passed to
     compute_parameter_gradients so that its type selects the overload. -->
training_label_type pick_which_run_update;
job_t next_job;
std::vector<font color='#5555FF'>&lt;</font>dlib::future<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>losses</font><font face='Lucida Console'>(</font>devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
std::vector<font color='#5555FF'>&lt;</font>tt::multi_device_tensor_averager<font color='#5555FF'>&gt;</font> averagers;
<font color='#009900'>// An array of all the parameter tensors in the first network. We will
</font> <font color='#009900'>// periodically copy these tensors to all the other devices to make sure the
</font> <font color='#009900'>// different GPUs don't go out of sync.
</font> std::vector<font color='#5555FF'>&lt;</font>tensor<font color='#5555FF'>*</font><font color='#5555FF'>&gt;</font> reference_params;
<font color='#BB00BB'>visit_layer_parameters</font><font face='Lucida Console'>(</font>devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>net, [<font color='#5555FF'>&amp;</font>]<font face='Lucida Console'>(</font>tensor<font color='#5555FF'>&amp;</font> t<font face='Lucida Console'>)</font> <b>{</b> reference_params.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#5555FF'>&amp;</font>t<font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
<font color='#009900'>// If no external thread pools vector was passed, then create one that will
</font> <font color='#009900'>// be automatically destructed as soon as the dnn_trainer object goes out of
</font> <font color='#009900'>// scope.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>thread_pools<font face='Lucida Console'>)</font>
thread_pools <font color='#5555FF'>=</font> std::make_shared<font color='#5555FF'>&lt;</font>threads<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> tp <font color='#5555FF'>=</font> <font color='#5555FF'>*</font>thread_pools;
<font color='#009900'>// We make separate thread pools with just one thread in them because we want
</font> <font color='#009900'>// to make sure each device is always executed on the same thread. We care
</font> <font color='#009900'>// about this because there are thread_local context variables for some cuda
</font> <font color='#009900'>// components and they get allocated for each combination of thread and device.
</font> <font color='#009900'>// So if we make sure the same device always uses the same thread this will
</font> <font color='#009900'>// reduce the number of contexts we allocate from num_devices*num_devices to
</font> <font color='#009900'>// just num_devices.
</font> <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>tp.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
tp.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'>&lt;</font>thread_pool<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
main_iteration_counter <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'>while</font><font face='Lucida Console'>(</font>job_pipe.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>next_job<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.test_only<font face='Lucida Console'>)</font>
<b>{</b>
<!-- The lambda captures next_job and pick_which_run_update by reference and i by value.
     Each result is written into losses[i] and read back through get() in the loop below. -->
<font color='#009900'>// compute the testing loss
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
tp[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font><font color='#BB00BB'>add_task_by_value</font><font face='Lucida Console'>(</font>[<font color='#5555FF'>&amp;</font>,i]<font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&amp;</font> loss<font face='Lucida Console'>)</font><b>{</b> loss <font color='#5555FF'>=</font> <font color='#BB00BB'>compute_parameter_gradients</font><font face='Lucida Console'>(</font>i, next_job, pick_which_run_update<font face='Lucida Console'>)</font>; <b>}</b>, losses[i]<font face='Lucida Console'>)</font>;
<font color='#009900'>// aggregate loss values from all the network computations.
</font> <font color='#0000FF'><u>double</u></font> theloss <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> loss : losses<font face='Lucida Console'>)</font>
theloss <font color='#5555FF'>+</font><font color='#5555FF'>=</font> loss.<font color='#BB00BB'>get</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<!-- NOTE(review): the sum is divided by losses.size(), which equals devices.size(). A device
     without data returns 0 from compute_parameter_gradients and is still counted in the divisor. -->
<font color='#BB00BB'>record_test_loss</font><font face='Lucida Console'>(</font>theloss<font color='#5555FF'>/</font>losses.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// Check if we should shrink the learning rate based on how the test
</font> <font color='#009900'>// error has been doing lately.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>learning_rate_shrink <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>
<b>{</b>
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#BB00BB'>count_steps_without_decrease</font><font face='Lucida Console'>(</font>test_previous_loss_values<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>test_steps_without_progress <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> test_iter_without_progress_thresh<font face='Lucida Console'>)</font>
<b>{</b>
<!-- Second opinion from the robust variant before the rate is shrunk; the training path
     further down uses the same two-stage check and explains it in its own comment. -->
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#BB00BB'>count_steps_without_decrease_robust</font><font face='Lucida Console'>(</font>test_previous_loss_values<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>test_steps_without_progress <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> test_iter_without_progress_thresh<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// optimization has flattened out, so drop the learning rate.
</font> learning_rate <font color='#5555FF'>=</font> learning_rate_shrink<font color='#5555FF'>*</font>learning_rate;
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#009900'>// Empty out some of the previous loss values so that test_steps_without_progress
</font> <font color='#009900'>// will decrease below test_iter_without_progress_thresh.
</font> <font color='#BB00BB'>drop_some_test_previous_loss_values</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<b>}</b>
<!-- A test job never updates parameters, so the training half of the loop body is skipped. -->
<font color='#0000FF'>continue</font>;
<b>}</b>
updated_net_since_last_sync <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>main_iteration_counter;
<font color='#009900'>// Call compute_parameter_gradients() and update_parameters() but pick the
</font> <font color='#009900'>// right version for unsupervised or supervised training based on the type
</font> <font color='#009900'>// of training_label_type.
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
tp[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font><font color='#BB00BB'>add_task_by_value</font><font face='Lucida Console'>(</font>[<font color='#5555FF'>&amp;</font>,i]<font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&amp;</font> loss<font face='Lucida Console'>)</font><b>{</b> loss <font color='#5555FF'>=</font> <font color='#BB00BB'>compute_parameter_gradients</font><font face='Lucida Console'>(</font>i, next_job, pick_which_run_update<font face='Lucida Console'>)</font>; <b>}</b>, losses[i]<font face='Lucida Console'>)</font>;
<font color='#009900'>// aggregate loss values from all the network computations.
</font> <font color='#0000FF'><u>double</u></font> theloss <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> loss : losses<font face='Lucida Console'>)</font>
theloss <font color='#5555FF'>+</font><font color='#5555FF'>=</font> loss.<font color='#BB00BB'>get</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>record_loss</font><font face='Lucida Console'>(</font>theloss<font color='#5555FF'>/</font>losses.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// Now, if there is more than one active device we need to synchronize the
</font> <font color='#009900'>// gradient updates between devices. So we do that now.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>
<b>{</b>
<!-- The averagers are rebuilt when they do not exist yet or when sync_file_reloaded is set;
     the flag is cleared at the end of the rebuild. -->
<font color='#009900'>// if this is the first iteration then we need to setup the averagers.
</font> <font color='#009900'>// We can't do this outside the loop because the tensors that get
</font> <font color='#009900'>// averaged need to be allocated to their devices before we call set()
</font> <font color='#009900'>// so that the averagers can determine how best to average them.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>averagers.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> sync_file_reloaded<font face='Lucida Console'>)</font>
<b>{</b>
averagers <font color='#5555FF'>=</font> std::vector<font color='#5555FF'>&lt;</font>tt::multi_device_tensor_averager<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>net_type::num_computational_layers<font face='Lucida Console'>)</font>;
<!-- all_tensors is indexed [device][layer]; the second loop below regroups the pointers
     per layer so that each averager receives one tensor from every device. -->
<font color='#009900'>// setup the averagers to point to the tensors in the networks.
</font> std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>tensor<font color='#5555FF'>*</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>all_tensors</font><font face='Lucida Console'>(</font>devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> all_tensors.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
<b>{</b>
all_tensors[i].<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>net_type::num_computational_layers<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>visit_layer_parameter_gradients</font><font face='Lucida Console'>(</font>devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>net, [<font color='#5555FF'>&amp;</font>]<font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> j, tensor<font color='#5555FF'>&amp;</font> t<font face='Lucida Console'>)</font><b>{</b>
all_tensors[i][j] <font color='#5555FF'>=</font> <font color='#5555FF'>&amp;</font>t;
<b>}</b><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Now set each averager to average the tensors at the same layer in each
</font> <font color='#009900'>// network.
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> net_type::num_computational_layers; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
<b>{</b>
std::vector<font color='#5555FF'>&lt;</font>tensor<font color='#5555FF'>*</font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>temp</font><font face='Lucida Console'>(</font>all_tensors.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> j <font color='#5555FF'>=</font> <font color='#979000'>0</font>; j <font color='#5555FF'>&lt;</font> all_tensors.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>j<font face='Lucida Console'>)</font>
<b>{</b>
temp[j] <font color='#5555FF'>=</font> all_tensors[j][i];
<!-- Every device must hold a gradient tensor of the same size as device 0 for this layer. -->
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>temp[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font><font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> temp[j]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font><font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
"<font color='#CC0000'>Make sure you don't modify the network structure </font>"
"<font color='#CC0000'>or number of parameters after constructing the trainer.</font>"<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ignore layers that don't have parameters
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>temp[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font><font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
averagers[i].<font color='#BB00BB'>set</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
<b>}</b>
sync_file_reloaded <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
<b>}</b>
<!-- Every device is synchronized before any averager runs. -->
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> d : devices<font face='Lucida Console'>)</font>
cuda::<font color='#BB00BB'>device_synchronize</font><font face='Lucida Console'>(</font>d<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>device_id<font face='Lucida Console'>)</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> avg : averagers<font face='Lucida Console'>)</font>
avg.<font color='#BB00BB'>average</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- Devices whose have_data flag is false for this job skip the parameter update. -->
<font color='#009900'>// Now apply all the updates to each device.
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
tp[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font><font color='#BB00BB'>add_task_by_value</font><font face='Lucida Console'>(</font>[<font color='#5555FF'>&amp;</font>,i]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.have_data[i]<font face='Lucida Console'>)</font> <font color='#BB00BB'>update_parameters</font><font face='Lucida Console'>(</font>i<font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
<font color='#009900'>// and wait for the updates to all happen.
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
tp[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font><font color='#BB00BB'>wait_for_all_tasks</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<!-- main_iteration_counter starts at 0 and is incremented before this point, so the test
     main_iteration_counter%2000 == 1 holds on the first training iteration and then once
     every 2000 iterations. reference_params holds pointers to the parameters of device 0. -->
<font color='#009900'>// Every now and then force all the parameters to be the same just to make
</font> <font color='#009900'>// sure they aren't drifting apart due to any non-deterministic behavior on
</font> <font color='#009900'>// the GPU. It's also important to do this on the first iteration because
</font> <font color='#009900'>// the different networks may be initialized differently when tensor data
</font> <font color='#009900'>// is first passed through them. So this code block deals with these
</font> <font color='#009900'>// issues.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>1</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> main_iteration_counter<font color='#5555FF'>%</font><font color='#979000'>2000</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>1</font>; i <font color='#5555FF'>&lt;</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>visit_layer_parameters</font><font face='Lucida Console'>(</font>devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>net, [<font color='#5555FF'>&amp;</font>]<font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> j, tensor<font color='#5555FF'>&amp;</font> t<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>memcpy</font><font face='Lucida Console'>(</font>t, <font color='#5555FF'>*</font>reference_params[j]<font face='Lucida Console'>)</font>;
<b>}</b><font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<!-- The lr_schedule branch further down is the else of this condition, so the schedule is
     consulted only on iterations where this condition is false. -->
<font color='#009900'>// If we have been running for a while then check if the loss is still
</font> <font color='#009900'>// dropping. If it isn't then we will reduce the learning rate. Note that we
</font> <font color='#009900'>// have a "budget" that prevents us from calling
</font> <font color='#009900'>// count_steps_without_decrease() every iteration. We do this because
</font> <font color='#009900'>// it can be expensive to compute when previous_loss_values is large.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>gradient_check_budget <font color='#5555FF'>&gt;</font> iter_without_progress_thresh <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> learning_rate_shrink <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>
<b>{</b>
gradient_check_budget <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
steps_without_progress <font color='#5555FF'>=</font> <font color='#BB00BB'>count_steps_without_decrease</font><font face='Lucida Console'>(</font>previous_loss_values<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>steps_without_progress <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> iter_without_progress_thresh<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// Double check that we aren't seeing decrease. This second check
</font> <font color='#009900'>// discards the top 10% largest values and checks again. We do
</font> <font color='#009900'>// this because sometimes a mini-batch might be bad and cause the
</font> <font color='#009900'>// loss to suddenly jump up, making count_steps_without_decrease()
</font> <font color='#009900'>// return a large number. But if we discard the top 10% of the
</font> <font color='#009900'>// values in previous_loss_values then we are robust to that kind
</font> <font color='#009900'>// of noise. Another way of looking at it, if the reason
</font> <font color='#009900'>// count_steps_without_decrease() returns a large value is only
</font> <font color='#009900'>// because the most recent loss values have suddenly been large,
</font> <font color='#009900'>// then we shouldn't stop or lower the learning rate. We should
</font> <font color='#009900'>// keep going until whatever disturbance we hit is damped down.
</font> steps_without_progress <font color='#5555FF'>=</font> <font color='#BB00BB'>count_steps_without_decrease_robust</font><font face='Lucida Console'>(</font>previous_loss_values<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>steps_without_progress <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> iter_without_progress_thresh<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// optimization has flattened out, so drop the learning rate.
</font> learning_rate <font color='#5555FF'>=</font> learning_rate_shrink<font color='#5555FF'>*</font>learning_rate;
steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#009900'>// Empty out some of the previous loss values so that steps_without_progress
</font> <font color='#009900'>// will decrease below iter_without_progress_thresh.
</font> <font color='#BB00BB'>drop_some_previous_loss_values</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<b>}</b>
<font color='#0000FF'>else</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>lr_schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#009900'>// or use the learning rate schedule if we have one.
</font> <b>{</b>
<!-- While entries remain, take the next one and advance lr_schedule_pos. After the schedule
     is exhausted the rate becomes the last entry scaled by 0.99.
     NOTE(review): presumably this pushes the rate just below the final scheduled value so
     that callers can detect the end of the schedule; confirm against the code that reads
     learning_rate. -->
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>lr_schedule_pos <font color='#5555FF'>&lt;</font> lr_schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
learning_rate <font color='#5555FF'>=</font> <font color='#BB00BB'>lr_schedule</font><font face='Lucida Console'>(</font>lr_schedule_pos<font color='#5555FF'>+</font><font color='#5555FF'>+</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>else</font>
learning_rate <font color='#5555FF'>=</font> <font color='#BB00BB'>lr_schedule</font><font face='Lucida Console'>(</font>lr_schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font><font color='#979000'>0.99</font>;
<b>}</b>
<b>}</b>
<b>}</b>
<!-- Handler of the function try block: it catches anything thrown from the body above.
     The exception is not rethrown here; it is stored in eptr while eptr_mutex is held. -->
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>...<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// If an exception happens then permanently disable the trainer object.
</font> job_pipe.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
std::lock_guard<font color='#5555FF'>&lt;</font>std::mutex<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>lock</font><font face='Lucida Console'>(</font>eptr_mutex<font face='Lucida Console'>)</font>;
eptr <font color='#5555FF'>=</font> std::<font color='#BB00BB'>current_exception</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- wait_for_thread_to_pause: const member that calls job_pipe.wait_for_num_blocked_dequeues(1).
     NOTE(review): presumably this returns once one thread (the thread() loop) is blocked inside
     job_pipe.dequeue(), i.e. all queued jobs have been consumed; confirm against the dlib::pipe
     documentation. -->
<font color='#0000FF'><u>void</u></font> <b><a name='wait_for_thread_to_pause'></a>wait_for_thread_to_pause</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
job_pipe.<font color='#BB00BB'>wait_for_num_blocked_dequeues</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
<b>}</b>
<!-- Three static long constants. Their use sites are outside this part of the listing.
     NOTE(review): the names suggest field widths for formatted output (general values, the
     epoch number, the learning rate); TODO confirm where they are read. -->
<font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>long</u></font> string_pad <font color='#5555FF'>=</font> <font color='#979000'>11</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>long</u></font> epoch_string_pad <font color='#5555FF'>=</font> <font color='#979000'>4</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>long</u></font> lr_string_pad <font color='#5555FF'>=</font> <font color='#979000'>4</font>;
// Sets every tunable setting and every piece of bookkeeping state to its default
// value and then calls start().
// NOTE(review): start() is defined outside this excerpt; presumably it launches
// the worker thread, so it must remain the last statement -- confirm.
<font color='#0000FF'><u>void</u></font> <b><a name='init'></a>init</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
// Default training settings.
max_num_epochs <font color='#5555FF'>=</font> <font color='#979000'>10000</font>;
mini_batch_size <font color='#5555FF'>=</font> <font color='#979000'>128</font>;
verbose <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
learning_rate <font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>2</font>;
min_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>5</font>;
// Progress thresholds and counters for the training loss and the test loss.
iter_without_progress_thresh <font color='#5555FF'>=</font> <font color='#979000'>2000</font>;
steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
test_iter_without_progress_thresh <font color='#5555FF'>=</font> <font color='#979000'>500</font>;
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
// Multiplier applied to learning_rate when it is shrunk (see sync_to_disk()).
learning_rate_shrink <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>;
// Position and call counters all start from zero.
epoch_iteration <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
epoch_pos <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
train_one_step_calls <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
test_one_step_calls <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
gradient_check_budget <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
lr_schedule_pos <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
main_iteration_counter <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
main_iteration_counter_at_last_disk_sync <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
// Threshold used by loss_increased_since_last_disk_sync(). It starts at the
// default value and is never raised above the max value there.
prob_loss_increasing_thresh_default_value <font color='#5555FF'>=</font> <font color='#979000'>0.99</font>;
prob_loss_increasing_thresh_max_value <font color='#5555FF'>=</font> <font color='#979000'>0.99999</font>;
prob_loss_increasing_thresh <font color='#5555FF'>=</font> prob_loss_increasing_thresh_default_value;
// Disk sync state: nothing has been updated or reloaded yet.
updated_net_since_last_sync <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
sync_file_reloaded <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
// Base number of stored loss values removed by the drop_some_*_loss_values()
// helpers (they add a tenth of the matching progress threshold on top).
previous_loss_values_dump_amount <font color='#5555FF'>=</font> <font color='#979000'>400</font>;
test_previous_loss_values_dump_amount <font color='#5555FF'>=</font> <font color='#979000'>100</font>;
// Fresh decayed running statistics object for the test loss, constructed with 200.
rs_test <font color='#5555FF'>=</font> running_stats_decayed<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>start</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// serialize and deserialize are private because we hold net by reference so
</font> <font color='#009900'>// allowing someone to serialize this training object is weird and will likely
</font> <font color='#009900'>// result in user errors. However, we use these functions as part of the automatic
</font> <font color='#009900'>// sync code in this object.
//
// Writes the trainer state of item to out. The call first waits for the worker
// thread to pause, then writes format version 13, the layer count, and every
// field in a fixed order. deserialize() reads the fields back in exactly this
// order, so any field added, removed or moved here must be mirrored there and
// the version number bumped in both places.
// Only the solvers of devices[0] are written; deserialize() copies them to the
// remaining devices. Members accessed through .load() are written by value.
</font> <font color='#0000FF'>friend</font> <font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dnn_trainer<font color='#5555FF'>&amp;</font> item, std::ostream<font color='#5555FF'>&amp;</font> out<font face='Lucida Console'>)</font>
<b>{</b>
item.<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'><u>int</u></font> version <font color='#5555FF'>=</font> <font color='#979000'>13</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>version, out<font face='Lucida Console'>)</font>;
// The layer count lets deserialize() reject a file saved for a different network.
<font color='#0000FF'><u>size_t</u></font> nl <font color='#5555FF'>=</font> dnn_trainer::num_layers;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>nl, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.rs, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.rs_test, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.previous_loss_values, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.max_num_epochs, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.mini_batch_size, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.verbose, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.net, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>solvers, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.learning_rate.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.min_learning_rate, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.iter_without_progress_thresh.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.steps_without_progress.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.learning_rate_shrink.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.epoch_iteration, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.epoch_pos, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.train_one_step_calls, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.test_one_step_calls, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.lr_schedule, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.lr_schedule_pos, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.test_iter_without_progress_thresh.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.test_steps_without_progress.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.test_previous_loss_values, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.previous_loss_values_dump_amount, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.test_previous_loss_values_dump_amount, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.previous_loss_values_to_keep_until_disk_sync, out<font face='Lucida Console'>)</font>;
<b>}</b>
// Restores the trainer state of item from in, reading the fields in the same
// order serialize() writes them. The call first waits for the worker thread to
// pause.
// Throws serialization_error when the stored version is not 13 or when the
// stored layer count differs from dnn_trainer::num_layers.
// Fields are assigned as they are read, so a failure part way through leaves
// item partially updated.
<font color='#0000FF'>friend</font> <font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>dnn_trainer<font color='#5555FF'>&amp;</font> item, std::istream<font color='#5555FF'>&amp;</font> in<font face='Lucida Console'>)</font>
<b>{</b>
item.<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'><u>int</u></font> version <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>version, in<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>version <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>13</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>throw</font> <font color='#BB00BB'>serialization_error</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>Unexpected version found while deserializing dlib::dnn_trainer.</font>"<font face='Lucida Console'>)</font>;
// Refuse a sync file that was written for a network with a different layer count.
<font color='#0000FF'><u>size_t</u></font> num_layers <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>num_layers, in<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>num_layers <font color='#5555FF'>!</font><font color='#5555FF'>=</font> dnn_trainer::num_layers<font face='Lucida Console'>)</font>
<b>{</b>
std::ostringstream sout;
sout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Error deserializing dlib::dnn_trainer. The saved sync file is for a network with </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;
sout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>a different number of layers. We expected the number of layers to be </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> dnn_trainer::num_layers <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> but</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;
sout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>instead the file contains </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> num_layers <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> layers.</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;
<font color='#0000FF'>throw</font> <font color='#BB00BB'>serialization_error</font><font face='Lucida Console'>(</font>sout.<font color='#BB00BB'>str</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
// Scratch values: members that serialize() reads through .load() are read into
// a plain temporary here and then assigned to the member.
<font color='#0000FF'><u>double</u></font> dtemp; <font color='#0000FF'><u>long</u></font> ltemp;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.rs, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.rs_test, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.previous_loss_values, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.max_num_epochs, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.mini_batch_size, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.verbose, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.net, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>solvers, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>dtemp, in<font face='Lucida Console'>)</font>; item.learning_rate <font color='#5555FF'>=</font> dtemp;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.min_learning_rate, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>ltemp, in<font face='Lucida Console'>)</font>; item.iter_without_progress_thresh <font color='#5555FF'>=</font> ltemp;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>ltemp, in<font face='Lucida Console'>)</font>; item.steps_without_progress <font color='#5555FF'>=</font> ltemp;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>dtemp, in<font face='Lucida Console'>)</font>; item.learning_rate_shrink <font color='#5555FF'>=</font> dtemp;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.epoch_iteration, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.epoch_pos, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.train_one_step_calls, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.test_one_step_calls, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.lr_schedule, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.lr_schedule_pos, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>ltemp, in<font face='Lucida Console'>)</font>; item.test_iter_without_progress_thresh <font color='#5555FF'>=</font> ltemp;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>ltemp, in<font face='Lucida Console'>)</font>; item.test_steps_without_progress <font color='#5555FF'>=</font> ltemp;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.test_previous_loss_values, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.previous_loss_values_dump_amount, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.test_previous_loss_values_dump_amount, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.previous_loss_values_to_keep_until_disk_sync, in<font face='Lucida Console'>)</font>;
// Only the state of devices[0] is stored in the file, so propagate it to every
// other device. The active CUDA device is remembered and restored afterwards.
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>item.devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> prev_dev <font color='#5555FF'>=</font> dlib::cuda::<font color='#BB00BB'>get_device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// initialize all the other device networks and solver objects
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>1</font>; i <font color='#5555FF'>&lt;</font> item.devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// Switch to this device so that any tensor objects that get allocated when
</font> <font color='#009900'>// we copy this stuff happen on this device.
</font> dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>item.devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>device_id<font face='Lucida Console'>)</font>;
item.devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>solvers <font color='#5555FF'>=</font> item.devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>solvers;
item.devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>net <font color='#5555FF'>=</font> item.devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>net;
<b>}</b>
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>prev_dev<font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<font color='#009900'>// Empty out some of the previous loss values so that steps_without_progress will decrease below iter_without_progress_thresh.
// Removes entries from the front of previous_loss_values: at most
// previous_loss_values_dump_amount plus a tenth of iter_without_progress_thresh
// of them, stopping early once the container is empty.
</font> <font color='#0000FF'><u>void</u></font> <b><a name='drop_some_previous_loss_values'></a>drop_some_previous_loss_values</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> cnt <font color='#5555FF'>=</font> <font color='#979000'>0</font>; cnt <font color='#5555FF'>&lt;</font> previous_loss_values_dump_amount <font color='#5555FF'>+</font> iter_without_progress_thresh <font color='#5555FF'>/</font> <font color='#979000'>10</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>cnt<font face='Lucida Console'>)</font>
previous_loss_values.<font color='#BB00BB'>pop_front</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Empty out some of the previous test loss values so that test_steps_without_progress will decrease below test_iter_without_progress_thresh.
// Removes entries from the front of test_previous_loss_values: at most
// test_previous_loss_values_dump_amount plus a tenth of
// test_iter_without_progress_thresh of them, stopping early once the container
// is empty.
</font> <font color='#0000FF'><u>void</u></font> <b><a name='drop_some_test_previous_loss_values'></a>drop_some_test_previous_loss_values</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> cnt <font color='#5555FF'>=</font> <font color='#979000'>0</font>; cnt <font color='#5555FF'>&lt;</font> test_previous_loss_values_dump_amount <font color='#5555FF'>+</font> test_iter_without_progress_thresh <font color='#5555FF'>/</font> <font color='#979000'>10</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> test_previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>cnt<font face='Lucida Console'>)</font>
test_previous_loss_values.<font color='#BB00BB'>pop_front</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
// Either saves the trainer state to a sync file or, when the loss appears to
// have been increasing since the last sync, reloads the newest saved state.
//
// do_it_now: when true, the time-between-syncs check is bypassed. The two early
// returns below (network not updated, no sync filename set) still apply.
//
// After either action the sync timestamp and iteration marker are refreshed and
// updated_net_since_last_sync is cleared.
<font color='#0000FF'><u>void</u></font> <b><a name='sync_to_disk'></a>sync_to_disk</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'><u>bool</u></font> do_it_now <font color='#5555FF'>=</font> <font color='#979000'>false</font>
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// don't sync anything if we haven't updated the network since the last sync
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>updated_net_since_last_sync<font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font>;
<font color='#009900'>// If the sync file isn't set then don't do anything.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>sync_filename.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font>;
<font color='#009900'>// Only sync if it has been long enough since the last sync or we are being
</font> <font color='#009900'>// explicitly forced to do it.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>std::chrono::system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>-</font> last_sync_time <font color='#5555FF'>&gt;</font> time_between_syncs <font color='#5555FF'>|</font><font color='#5555FF'>|</font>
do_it_now<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// compact network before saving to disk.
</font> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// if the loss has actually been going up since the last time we saved our
</font> <font color='#009900'>// state to disk then something has probably gone wrong in the
</font> <font color='#009900'>// optimization. So in this case we do the opposite and recall the
</font> <font color='#009900'>// previously saved state in the hopes that the problem won't reoccur.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>loss_increased_since_last_disk_sync</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<b>{</b>
// The stream is not checked after opening here. Note that
// loss_increased_since_last_disk_sync() returns false when the newest sync
// file cannot be opened, so this branch is only reached if it could be.
std::ifstream <font color='#BB00BB'>fin</font><font face='Lucida Console'>(</font><font color='#BB00BB'>newest_syncfile</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, std::ios::binary<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font>, fin<font face='Lucida Console'>)</font>;
sync_file_reloaded <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font>
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Loss has been increasing, reloading saved state from </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>newest_syncfile</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;
<font color='#009900'>// Are we repeatedly hitting our head against the wall? If so, then we
</font> <font color='#009900'>// might be better off giving up at this learning rate, and trying a
</font> <font color='#009900'>// lower one instead.
// The threshold reaches its max value only after
// loss_increased_since_last_disk_sync() has raised it on earlier detections.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>prob_loss_increasing_thresh <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> prob_loss_increasing_thresh_max_value<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font>
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>(and while at it, also shrinking the learning rate)</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;
learning_rate <font color='#5555FF'>=</font> learning_rate_shrink <font color='#5555FF'>*</font> learning_rate;
steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#BB00BB'>drop_some_previous_loss_values</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>drop_some_test_previous_loss_values</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<font color='#0000FF'>else</font>
<b>{</b>
// Normal case: write the current state to the file chosen by oldest_syncfile().
// NOTE(review): presumably this overwrites the older of the two sync files so
// the most recent good copy stays on disk -- confirm in select_oldest_file.
<font color='#0000FF'>const</font> std::string filename <font color='#5555FF'>=</font> <font color='#BB00BB'>oldest_syncfile</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>filename<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#5555FF'>*</font><font color='#0000FF'>this</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font>
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Saved state to </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> filename <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;
<b>}</b>
// Bookkeeping common to both the reload and the save path.
last_sync_time <font color='#5555FF'>=</font> std::chrono::system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
main_iteration_counter_at_last_disk_sync <font color='#5555FF'>=</font> main_iteration_counter;
updated_net_since_last_sync <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
<b>}</b>
<b>}</b>
// Returns the result of select_newest_file() applied to the two candidate sync
// file paths: sync_filename and sync_filename with "_" appended.
// NOTE(review): presumably that is whichever of the two was written most
// recently -- select_newest_file is defined outside this excerpt, confirm.
std::string <b><a name='newest_syncfile'></a>newest_syncfile</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>return</font> <font color='#BB00BB'>select_newest_file</font><font face='Lucida Console'>(</font>sync_filename, sync_filename <font color='#5555FF'>+</font> "<font color='#CC0000'>_</font>"<font face='Lucida Console'>)</font>;
<b>}</b>
// Returns the result of select_oldest_file() applied to the two candidate sync
// file paths: sync_filename and sync_filename with "_" appended.
// NOTE(review): presumably that is whichever of the two was written least
// recently -- select_oldest_file is defined outside this excerpt, confirm.
std::string <b><a name='oldest_syncfile'></a>oldest_syncfile</b> <font face='Lucida Console'>(</font>
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>return</font> <font color='#BB00BB'>select_oldest_file</font><font face='Lucida Console'>(</font>sync_filename, sync_filename <font color='#5555FF'>+</font> "<font color='#CC0000'>_</font>"<font face='Lucida Console'>)</font>;
<b>}</b>
// Decides whether the loss recorded in previous_loss_values_to_keep_until_disk_sync
// has been getting worse since the last disk sync.
//
// Returns false when the newest sync file cannot be opened, or when fewer than 30
// gradient updates happened since the last sync and no nan/inf loss was seen.
// Returns true when any retained loss value is nan or inf, or when the larger of
// the two increasing-probability estimates exceeds prob_loss_increasing_thresh.
//
// Side effects: trims old entries from previous_loss_values_to_keep_until_disk_sync
// and, once the probability test is reached, adjusts prob_loss_increasing_thresh.
<font color='#0000FF'><u>bool</u></font> <b><a name='loss_increased_since_last_disk_sync'></a>loss_increased_since_last_disk_sync</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'><u>size_t</u></font> gradient_updates_since_last_sync <font color='#5555FF'>=</font> main_iteration_counter <font color='#5555FF'>-</font> main_iteration_counter_at_last_disk_sync;
<font color='#009900'>// if we haven't synced anything to disk yet then return false.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>std::<font color='#BB00BB'>ifstream</font><font face='Lucida Console'>(</font><font color='#BB00BB'>newest_syncfile</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, std::ios::binary<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font> <font color='#979000'>false</font>;
<font color='#009900'>// Now look at the data since a little before the last disk sync. We will
</font> <font color='#009900'>// check if the loss is getting better or worse.
// Concretely: keep at most twice as many values as there were gradient updates
// since the last sync, discarding the oldest ones first.
</font> <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>previous_loss_values_to_keep_until_disk_sync.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>2</font> <font color='#5555FF'>*</font> gradient_updates_since_last_sync<font face='Lucida Console'>)</font>
previous_loss_values_to_keep_until_disk_sync.<font color='#BB00BB'>pop_front</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// Always retry if there are any nan or inf values
// This check runs before the minimum-data check below, so a nan or inf triggers
// a reload even when very few updates have happened.
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font> x : previous_loss_values_to_keep_until_disk_sync<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>std::<font color='#BB00BB'>isnan</font><font face='Lucida Console'>(</font>x<font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> std::<font color='#BB00BB'>isinf</font><font face='Lucida Console'>(</font>x<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font> <font color='#979000'>true</font>;
<b>}</b>
<font color='#009900'>// if we haven't seen much data yet then just say false.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>gradient_updates_since_last_sync <font color='#5555FF'>&lt;</font> <font color='#979000'>30</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font> <font color='#979000'>false</font>;
<font color='#009900'>// if the loss is very likely to be increasing then return true
// Two estimators are consulted and the larger probability is used.
</font> <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> prob1 <font color='#5555FF'>=</font> <font color='#BB00BB'>probability_values_are_increasing</font><font face='Lucida Console'>(</font>previous_loss_values_to_keep_until_disk_sync<font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> prob2 <font color='#5555FF'>=</font> <font color='#BB00BB'>probability_values_are_increasing_robust</font><font face='Lucida Console'>(</font>previous_loss_values_to_keep_until_disk_sync<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>std::<font color='#BB00BB'>max</font><font face='Lucida Console'>(</font>prob1, prob2<font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> prob_loss_increasing_thresh<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// Exponentially decay the threshold towards 1 so that if we keep finding
</font> <font color='#009900'>// the loss to be increasing over and over we will make the test
</font> <font color='#009900'>// progressively harder and harder until it fails, therefore ensuring we
</font> <font color='#009900'>// can't get stuck reloading from a previous state over and over.
// New threshold is 0.1 * old + 0.9, capped at prob_loss_increasing_thresh_max_value.
</font> prob_loss_increasing_thresh <font color='#5555FF'>=</font> std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>
<font color='#979000'>0.1</font><font color='#5555FF'>*</font>prob_loss_increasing_thresh <font color='#5555FF'>+</font> <font color='#979000'>0.9</font><font color='#5555FF'>*</font><font color='#979000'>1</font>,
prob_loss_increasing_thresh_max_value
<font face='Lucida Console'>)</font>;
<font color='#0000FF'>return</font> <font color='#979000'>true</font>;
<b>}</b>
<font color='#0000FF'>else</font>
<b>{</b>
<font color='#009900'>// decay back to the default threshold
// (the threshold is raised to the 10th power, which lowers a value below 1)
</font> prob_loss_increasing_thresh <font color='#5555FF'>=</font> std::<font color='#BB00BB'>pow</font><font face='Lucida Console'>(</font>prob_loss_increasing_thresh, <font color='#979000'>10.0</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// but don't decay below the default value
</font> prob_loss_increasing_thresh <font color='#5555FF'>=</font> std::<font color='#BB00BB'>max</font><font face='Lucida Console'>(</font>prob_loss_increasing_thresh, prob_loss_increasing_thresh_default_value<font face='Lucida Console'>)</font>;
<font color='#0000FF'>return</font> <font color='#979000'>false</font>;
<b>}</b>
<b>}</b>
// Empty tag type: passing a clone_net as the fourth constructor argument selects the
// device_data overload that stores its own copy of the network in net_copy.
<font color='#0000FF'>struct</font> <b><a name='clone_net'></a>clone_net</b><b>{</b><b>}</b>;
<font color='#009900'>// per device state. All the containers have the same number of objects in them.
//
// Members:
//   device_id - id of the device this state belongs to; send_job() hands it to
//               dlib::cuda::set_device() before building that device's tensor.
//   net_copy  - owning pointer to a private copy of the network. It is only set by the
//               constructor that takes a clone_net tag; otherwise it stays empty.
//   net       - the network this device works on: the caller's net_ for the first
//               constructor, *net_copy for the clone_net constructor.
//   solvers   - num_computational_layers copies of the solver_ passed to the constructor.
//
// net_copy is declared before net on purpose: members are initialized in declaration
// order, and the clone_net constructor binds net to *net_copy, so net_copy has to be
// constructed first.
</font> <font color='#0000FF'>struct</font> <b><a name='device_data'></a>device_data</b>
<b>{</b>
<b><a name='device_data'></a>device_data</b><font face='Lucida Console'>(</font>
<font color='#0000FF'><u>int</u></font> device_id_,
net_type<font color='#5555FF'>&amp;</font> net_,
<font color='#0000FF'>const</font> solver_type<font color='#5555FF'>&amp;</font> solver_
<font face='Lucida Console'>)</font> : device_id<font face='Lucida Console'>(</font>device_id_<font face='Lucida Console'>)</font>, net<font face='Lucida Console'>(</font>net_<font face='Lucida Console'>)</font>, solvers<font face='Lucida Console'>(</font>num_computational_layers, solver_<font face='Lucida Console'>)</font> <b>{</b><b>}</b>
<b><a name='device_data'></a>device_data</b><font face='Lucida Console'>(</font>
<font color='#0000FF'><u>int</u></font> device_id_,
net_type<font color='#5555FF'>&amp;</font> net_,
<font color='#0000FF'>const</font> solver_type<font color='#5555FF'>&amp;</font> solver_,
clone_net
<font face='Lucida Console'>)</font> : device_id<font face='Lucida Console'>(</font>device_id_<font face='Lucida Console'>)</font>, net_copy<font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'>&lt;</font>net_type<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>net_<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>, net<font face='Lucida Console'>(</font><font color='#5555FF'>*</font>net_copy<font face='Lucida Console'>)</font>, solvers<font face='Lucida Console'>(</font>num_computational_layers, solver_<font face='Lucida Console'>)</font> <b>{</b><b>}</b>
<font color='#0000FF'><u>int</u></font> device_id;
std::shared_ptr<font color='#5555FF'>&lt;</font>net_type<font color='#5555FF'>&gt;</font> net_copy;
net_type<font color='#5555FF'>&amp;</font> net;
std::vector<font color='#5555FF'>&lt;</font>solver_type<font color='#5555FF'>&gt;</font> solvers;
<b>}</b>;
// Splits the samples in [dbegin, dend), and the labels starting at lbegin, into
// devices.size() contiguous chunks, converts each non-empty chunk into the input
// tensor for the matching device, and enqueues the filled job on job_pipe.
//
//   test_only    - copied into job.test_only.
//   dbegin, dend - range of input samples; the iterators are used with std::distance
//                  and with iterator + offset arithmetic.
//   lbegin       - first label; labels are read at the same offsets as the samples.
//
// propagate_exception() is called first, so an exception already stored in eptr is
// rethrown from here before any work is done.
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
<font color='#0000FF'>typename</font> data_iterator,
<font color='#0000FF'>typename</font> label_iterator
<font color='#5555FF'>&gt;</font>
<font color='#0000FF'><u>void</u></font> <b><a name='send_job'></a>send_job</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'><u>bool</u></font> test_only,
data_iterator dbegin,
data_iterator dend,
label_iterator lbegin
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>propagate_exception</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'><u>size_t</u></font> num <font color='#5555FF'>=</font> std::<font color='#BB00BB'>distance</font><font face='Lucida Console'>(</font>dbegin, dend<font face='Lucida Console'>)</font>;
<font color='#0000FF'><u>size_t</u></font> devs <font color='#5555FF'>=</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
// job is a member that is reused between calls; give it one slot per device.
job.t.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>devs<font face='Lucida Console'>)</font>;
job.labels.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>devs<font face='Lucida Console'>)</font>;
job.have_data.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>devs<font face='Lucida Console'>)</font>;
job.test_only <font color='#5555FF'>=</font> test_only;
<font color='#009900'>// chop the data into devs blocks, each of about block_size elements.
</font> <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> block_size <font color='#5555FF'>=</font> num <font color='#5555FF'>/</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>devs<font face='Lucida Console'>)</font>;
// remember the currently selected device so it can be restored after the loop.
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> prev_dev <font color='#5555FF'>=</font> dlib::cuda::<font color='#BB00BB'>get_device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
// j is the fractional start offset of the current chunk; it is rounded to get the
// integer bounds, so consecutive chunks share a boundary and never overlap.
<font color='#0000FF'><u>double</u></font> j <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> devs; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
<b>{</b>
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>device_id<font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>size_t</u></font> start <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>size_t</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>round</font><font face='Lucida Console'>(</font>j<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>size_t</u></font> stop <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>size_t</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>round</font><font face='Lucida Console'>(</font>j <font color='#5555FF'>+</font> block_size<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
// a chunk can round to empty, e.g. when there are fewer samples than devices.
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>start <font color='#5555FF'>&lt;</font> stop<font face='Lucida Console'>)</font>
<b>{</b>
devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>net.<font color='#BB00BB'>to_tensor</font><font face='Lucida Console'>(</font>dbegin<font color='#5555FF'>+</font>start, dbegin<font color='#5555FF'>+</font>stop, job.t[i]<font face='Lucida Console'>)</font>;
job.labels[i].<font color='#BB00BB'>assign</font><font face='Lucida Console'>(</font>lbegin<font color='#5555FF'>+</font>start, lbegin<font color='#5555FF'>+</font>stop<font face='Lucida Console'>)</font>;
job.have_data[i] <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
<b>}</b>
<font color='#0000FF'>else</font>
<b>{</b>
// only the flag is updated; job.t[i] and job.labels[i] keep whatever they held before.
job.have_data[i] <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
<b>}</b>
j <font color='#5555FF'>+</font><font color='#5555FF'>=</font> block_size;
<b>}</b>
// sanity check: after devs steps of block_size the offset should have reached num.
<font color='#BB00BB'>DLIB_ASSERT</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>fabs</font><font face='Lucida Console'>(</font>j <font color='#5555FF'>-</font> num<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>10</font><font face='Lucida Console'>)</font>;
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>prev_dev<font face='Lucida Console'>)</font>;
job_pipe.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>job<font face='Lucida Console'>)</font>;
<b>}</b>
// Overload for data that comes without labels. It forwards to the four-argument
// send_job(), passing a default-constructed iterator of a std::vector of
// training_label_type as lbegin.
// NOTE(review): the four-argument overload still evaluates lbegin+start and lbegin+stop
// on this placeholder iterator; presumably the label type used on this path makes that
// harmless - confirm against the callers.
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
<font color='#0000FF'>typename</font> data_iterator
<font color='#5555FF'>&gt;</font>
<font color='#0000FF'><u>void</u></font> <b><a name='send_job'></a>send_job</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'><u>bool</u></font> test_only,
data_iterator dbegin,
data_iterator dend
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>typename</font> std::vector<font color='#5555FF'>&lt;</font>training_label_type<font color='#5555FF'>&gt;</font>::iterator nothing;
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font>test_only, dbegin, dend, nothing<font face='Lucida Console'>)</font>;
<b>}</b>
// Writes one progress fragment to std::cout and ends the line with std::endl.
// If lr_schedule is empty it prints the number of steps without apparent progress:
// the training counter alone when test_previous_loss_values is empty, otherwise both
// the training and the test counters. If lr_schedule is not empty it prints how far
// lr_schedule_pos is through the schedule as a percentage with two decimals.
<font color='#0000FF'><u>void</u></font> <b><a name='print_progress'></a>print_progress</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>lr_schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>test_previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>steps without apparent progress: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> steps_without_progress;
<font color='#0000FF'>else</font>
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>steps without apparent progress: train=</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> steps_without_progress <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>, test=</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> test_steps_without_progress;
<b>}</b>
<font color='#0000FF'>else</font>
<b>{</b>
// the percentage is formatted in a separate stream, so std::fixed and
// std::setprecision are never applied to std::cout itself.
std::ostringstream sout;
sout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>percent complete: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::fixed <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::<font color='#BB00BB'>setprecision</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#979000'>100.0</font><font color='#5555FF'>*</font>lr_schedule_pos<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font>lr_schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>%</font>";
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> sout.<font color='#BB00BB'>str</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> std::endl;
<b>}</b>
// Prints a status line to std::cout, but only when verbose is set and more than
// 40 seconds have passed since last_time. The line holds the step count, the learning
// rate, the loss value(s) and the output of print_progress(). After printing it calls
// clear_average_loss(). In every other case the function does nothing.
<font color='#0000FF'><u>void</u></font> <b><a name='print_periodic_verbose_status'></a>print_periodic_verbose_status</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std::chrono;
<font color='#0000FF'>auto</font> now_time <font color='#5555FF'>=</font> system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>now_time<font color='#5555FF'>-</font>last_time <font color='#5555FF'>&gt;</font> <font color='#BB00BB'>seconds</font><font face='Lucida Console'>(</font><font color='#979000'>40</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<b>{</b>
// restart the 40 second window from this report.
last_time <font color='#5555FF'>=</font> now_time;
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>step#: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>train_one_step_calls<font face='Lucida Console'>)</font>,epoch_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>"
<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>learning rate: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>learning_rate<font face='Lucida Console'>)</font>,lr_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>";
// with no recorded test loss values only one loss figure is shown; otherwise the
// training and test losses are printed side by side.
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>test_previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<b>{</b>
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>average loss: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>";
<b>}</b>
<font color='#0000FF'>else</font>
<b>{</b>
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>train loss: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>";
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>test loss: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_test_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>";
<b>}</b>
// print_progress() finishes the line (it ends with std::endl).
<font color='#BB00BB'>print_progress</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>clear_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<b>}</b>
// Trainer state. The notes below only state what the code in this part of the file shows.
// devices: one device_data per device; send_job() builds one chunk of work for each entry.
std::vector<font color='#5555FF'>&lt;</font>std::shared_ptr<font color='#5555FF'>&lt;</font>device_data<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> devices;
// job_pipe: send_job() enqueues each filled job here.
dlib::pipe<font color='#5555FF'>&lt;</font>job_t<font color='#5555FF'>&gt;</font> job_pipe;
std::shared_ptr<font color='#5555FF'>&lt;</font>threads<font color='#5555FF'>&gt;</font> thread_pools;
// job: scratch job object that send_job() refills on every call before enqueueing it.
job_t job;
running_stats<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> rs;
running_stats_decayed<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> rs_test;
std::deque<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> previous_loss_values;
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> max_num_epochs;
<font color='#0000FF'><u>size_t</u></font> mini_batch_size;
// verbose: gates the output of print_periodic_verbose_status().
<font color='#0000FF'><u>bool</u></font> verbose;
net_type<font color='#5555FF'>&amp;</font> net;
std::atomic<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> learning_rate;
<font color='#0000FF'><u>double</u></font> min_learning_rate;
std::atomic<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> iter_without_progress_thresh;
// steps_without_progress / test_steps_without_progress: the counters that
// print_progress() reports as "steps without apparent progress".
std::atomic<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> steps_without_progress;
std::atomic<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> test_iter_without_progress_thresh;
std::atomic<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> test_steps_without_progress;
// test_previous_loss_values: when empty, the status printers omit the test figures.
std::deque<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> test_previous_loss_values;
std::deque<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> previous_loss_values_to_keep_until_disk_sync;
std::atomic<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> learning_rate_shrink;
std::chrono::time_point<font color='#5555FF'>&lt;</font>std::chrono::system_clock<font color='#5555FF'>&gt;</font> last_sync_time;
std::string sync_filename;
std::chrono::seconds time_between_syncs;
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> epoch_iteration;
<font color='#0000FF'><u>size_t</u></font> epoch_pos;
// last_time: time of the most recent report from print_periodic_verbose_status().
std::chrono::time_point<font color='#5555FF'>&lt;</font>std::chrono::system_clock<font color='#5555FF'>&gt;</font> last_time;
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <font color='#0000FF'><u>long</u></font> train_one_step_calls;
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <font color='#0000FF'><u>long</u></font> test_one_step_calls;
// lr_schedule / lr_schedule_pos: when lr_schedule is not empty, print_progress()
// reports lr_schedule_pos divided by lr_schedule.size() as percent complete.
matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>&gt;</font> lr_schedule;
<font color='#0000FF'><u>long</u></font> lr_schedule_pos;
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> gradient_check_budget;
// eptr: stored exception that propagate_exception() rethrows; eptr_mutex is held
// while eptr is read there, and is mutable so that const members can lock it.
std::exception_ptr eptr <font color='#5555FF'>=</font> nullptr;
<font color='#0000FF'>mutable</font> std::mutex eptr_mutex;
// Rethrows the exception stored in eptr, if there is one; otherwise does nothing.
// eptr is read while eptr_mutex is held.
// NOTE(review): eptr is not cleared here, so every later call rethrows the same
// exception until some other code resets it - confirm that this is intended.
<font color='#0000FF'><u>void</u></font> <b><a name='propagate_exception'></a>propagate_exception</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
std::lock_guard<font color='#5555FF'>&lt;</font>std::mutex<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>lock</font><font face='Lucida Console'>(</font>eptr_mutex<font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>eptr<font face='Lucida Console'>)</font>
std::<font color='#BB00BB'>rethrow_exception</font><font face='Lucida Console'>(</font>eptr<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// These 5 variables are not serialized
// NOTE(review): nine members are declared below this comment, so "these 5" presumably
// means only the first five (main_iteration_counter through prob_loss_increasing_thresh).
// Confirm against the serialize()/deserialize() routines which members are really excluded.
</font> <font color='#0000FF'><u>size_t</u></font> main_iteration_counter;
<font color='#0000FF'><u>size_t</u></font> main_iteration_counter_at_last_disk_sync;
<font color='#0000FF'><u>double</u></font> prob_loss_increasing_thresh_default_value;
<font color='#0000FF'><u>double</u></font> prob_loss_increasing_thresh_max_value;
<font color='#0000FF'><u>double</u></font> prob_loss_increasing_thresh;
std::atomic<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>bool</u></font><font color='#5555FF'>&gt;</font> updated_net_since_last_sync;
<font color='#0000FF'><u>bool</u></font> sync_file_reloaded;
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> previous_loss_values_dump_amount;
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> test_previous_loss_values_dump_amount;
<b>}</b>;
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
// Writes a multi-line, human-readable summary of a dnn_trainer to out and returns out.
// The summary lists: net_type::num_layers, the serialized size in MiB of a cleaned copy
// of the net, an md5 hash of the net's subnet converted to a string, the loss details,
// the train_one_step call count, the synchronization file, the first solver and the
// mini-batch size. It ends with either a note that an explicit learning rate schedule
// is in use, or the learning rate settings and progress thresholds.
//
//   out     - stream to write to.
//   trainer - trainer to describe. NOTE(review): taken by non-const reference,
//             presumably because some of the getters used here are non-const - confirm.
//
// NOTE(review): get_solvers()[0] is read unconditionally - confirm that the solver
// list can never be empty.
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
<font color='#0000FF'>typename</font> net_type,
<font color='#0000FF'>typename</font> solver_type
<font color='#5555FF'>&gt;</font>
std::ostream<font color='#5555FF'>&amp;</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font face='Lucida Console'>(</font>
std::ostream<font color='#5555FF'>&amp;</font> out,
dnn_trainer<font color='#5555FF'>&lt;</font>net_type,solver_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> trainer
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>using</font> std::endl;
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>dnn_trainer details: \n</font>";
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> net_type::num_layers: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> net_type::num_layers <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<font color='#009900'>// figure out how big the net is in MB.
</font> std::ostringstream sout;
net_type temp <font color='#5555FF'>=</font> trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#009900'>// make a copy so that we can clean it without mutating the trainer's net.
</font> temp.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>temp, sout<font face='Lucida Console'>)</font>;
// the size is the byte count of the serialized copy divided by 1024 twice, i.e. MiB.
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> net size: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> sout.<font color='#BB00BB'>str</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font><font color='#979000'>1024.0</font><font color='#5555FF'>/</font><font color='#979000'>1024.0</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> MiB</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<font color='#009900'>// Don't include the loss params in the hash since we print them on the next line.
</font> <font color='#009900'>// They also aren't really part of the "architecture" of the network.
</font> out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> net architecture hash: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>md5</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> loss: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>loss_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> get_train_one_step_calls(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer.<font color='#BB00BB'>get_train_one_step_calls</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> synchronization file: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer.<font color='#BB00BB'>get_synchronization_file</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> trainer.get_solvers()[0]: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer.<font color='#BB00BB'>get_solvers</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>[<font color='#979000'>0</font>] <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> mini batch size: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer.<font color='#BB00BB'>get_mini_batch_size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
// the learning rate settings are only printed when no explicit schedule was supplied.
<font color='#0000FF'>auto</font> sched <font color='#5555FF'>=</font> trainer.<font color='#BB00BB'>get_learning_rate_schedule</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>sched.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<b>{</b>
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> using explicit user-supplied learning rate schedule</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<b>}</b>
<font color='#0000FF'>else</font>
<b>{</b>
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> learning rate: </font>"<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> learning rate shrink factor: </font>"<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer.<font color='#BB00BB'>get_learning_rate_shrink_factor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> min learning rate: </font>"<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer.<font color='#BB00BB'>get_min_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> iterations without progress threshold: </font>"<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer.<font color='#BB00BB'>get_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
out <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> test iterations without progress threshold: </font>"<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer.<font color='#BB00BB'>get_test_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<b>}</b>
<font color='#0000FF'>return</font> out;
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<b>}</b>
<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_DNn_TRAINER_H_
</font>
</pre></body></html>