|
<html><head><title>dlib C++ Library - trainer.h</title></head><body bgcolor='white'><pre> |
|
<font color='#009900'>// Copyright (C) 2015 Davis E. King (davis@dlib.net) |
|
</font><font color='#009900'>// License: Boost Software License See LICENSE.txt for the full license. |
|
</font><font color='#0000FF'>#ifndef</font> DLIB_DNn_TRAINER_H_ |
|
<font color='#0000FF'>#define</font> DLIB_DNn_TRAINER_H_ |
|
|
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='trainer_abstract.h.html'>trainer_abstract.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='core.h.html'>core.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='solvers.h.html'>solvers.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../statistics.h.html'>../statistics.h</a>" |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>chrono<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>fstream<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>sstream<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../serialize.h.html'>../serialize.h</a>" |
|
|
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../pipe.h.html'>../pipe.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../threads.h.html'>../threads.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../cuda/cuda_dlib.h.html'>../cuda/cuda_dlib.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../statistics/running_gradient.h.html'>../statistics/running_gradient.h</a>" |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>atomic<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>cstdio<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>set<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>future<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>exception<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>mutex<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../dir_nav.h.html'>../dir_nav.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../md5.h.html'>../md5.h</a>" |
|
|
|
<font color='#0000FF'>namespace</font> dlib |
|
<b>{</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>namespace</font> impl |
|
<b>{</b> |
|
<!-- dnn_job_t: plain record templated on training_label_type.
     It is default constructible; copy construction and copy assignment are deleted.
     It carries three vectors indexed together (labels, t, have_data) plus a test_only
     flag that defaults to false. have_data stores ints that are used as booleans, as
     the inline comment on that member states.
     NOTE(review): the index i presumably selects one device; confirm against the code
     that fills the job, which lies outside this excerpt of the listing. -->
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> training_label_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'>struct</font> <b><a name='dnn_job_t'></a>dnn_job_t</b> |
|
<b>{</b> |
|
<b><a name='dnn_job_t'></a>dnn_job_t</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>default</font>; |
|
<b><a name='dnn_job_t'></a>dnn_job_t</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dnn_job_t<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>; |
|
dnn_job_t<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>=</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dnn_job_t<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>; |
|
|
|
std::vector<font color='#5555FF'><</font>std::vector<font color='#5555FF'><</font>training_label_type<font color='#5555FF'>></font><font color='#5555FF'>></font> labels; |
|
std::vector<font color='#5555FF'><</font>resizable_tensor<font color='#5555FF'>></font> t; |
|
std::vector<font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font><font color='#5555FF'>></font> have_data; <font color='#009900'>// have_data[i] is true if there is data in labels[i] and t[i]. |
|
</font> <font color='#0000FF'><u>bool</u></font> test_only <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
<b>}</b>; |
|
|
|
<!-- swap(a, b): exchanges the contents of two dnn_job_t objects member by member.
     The three vectors (labels, t, have_data) are exchanged through their swap member
     functions and the test_only flag through std::swap. Returns nothing.
     NOTE(review): presumably provided because dnn_job_t cannot be copied; confirm
     with the code that moves jobs around, which is outside this excerpt. -->
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> training_label_type<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='swap'></a>swap</b><font face='Lucida Console'>(</font>dnn_job_t<font color='#5555FF'><</font>training_label_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> a, dnn_job_t<font color='#5555FF'><</font>training_label_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> b<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
a.labels.<font color='#BB00BB'>swap</font><font face='Lucida Console'>(</font>b.labels<font face='Lucida Console'>)</font>; |
|
a.t.<font color='#BB00BB'>swap</font><font face='Lucida Console'>(</font>b.t<font face='Lucida Console'>)</font>; |
|
a.have_data.<font color='#BB00BB'>swap</font><font face='Lucida Console'>(</font>b.have_data<font face='Lucida Console'>)</font>; |
|
std::<font color='#BB00BB'>swap</font><font face='Lucida Console'>(</font>a.test_only,b.test_only<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<!-- force_flush_to_disk: scoped enum with two values, no (0) and yes (1).
     Within this excerpt it is the parameter type of dnn_trainer::get_net, which
     defaults it to yes and compares it against yes to form the bool argument
     handed to sync_to_disk. -->
<font color='#0000FF'>enum</font> <font color='#0000FF'>class</font> <b><a name='force_flush_to_disk'></a>force_flush_to_disk</b> <b>{</b> |
|
no <font color='#5555FF'>=</font> <font color='#979000'>0</font>, |
|
yes <font color='#5555FF'>=</font> <font color='#979000'>1</font> |
|
<b>}</b>; |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> |
|
<font color='#0000FF'>typename</font> net_type, |
|
<font color='#0000FF'>typename</font> solver_type <font color='#5555FF'>=</font> sgd |
|
<font color='#5555FF'>></font> |
|
<font color='#0000FF'>class</font> <b><a name='dnn_trainer'></a>dnn_trainer</b> : <font color='#0000FF'>private</font> threaded_object |
|
<b>{</b> |
|
<font color='#0000FF'>public</font>: |
|
|
|
<!-- Member declarations at the top of dnn_trainer.
     * static_assert: compile time check that is_loss_layer_type is true for net_type.
     * training_label_type / input_type: public aliases of the types of the same name
       nested in net_type.
     * num_computational_layers / num_layers: constants copied from net_type.
     * threads: alias for a vector of shared_ptr to thread_pool.
     * job_t: private alias for impl::dnn_job_t instantiated with training_label_type.
     * Default construction, copy construction and copy assignment are all deleted. -->
<b><a name='static_assert'></a>static_assert</b><font face='Lucida Console'>(</font>is_loss_layer_type<font color='#5555FF'><</font>net_type<font color='#5555FF'>></font>::value, |
|
"<font color='#CC0000'>The last layer in a network must be a loss layer.</font>"<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> net_type::training_label_type training_label_type; |
|
<font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> net_type::input_type input_type; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>size_t</u></font> num_computational_layers <font color='#5555FF'>=</font> net_type::num_computational_layers; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>size_t</u></font> num_layers <font color='#5555FF'>=</font> net_type::num_layers; |
|
<font color='#0000FF'>using</font> threads <font color='#5555FF'>=</font> std::vector<font color='#5555FF'><</font>std::shared_ptr<font color='#5555FF'><</font>thread_pool<font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
<font color='#0000FF'>private</font>: |
|
<font color='#0000FF'>typedef</font> impl::dnn_job_t<font color='#5555FF'><</font>training_label_type<font color='#5555FF'>></font> job_t; |
|
<font color='#0000FF'>public</font>: |
|
|
|
<b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>; |
|
<b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dnn_trainer<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>; |
|
dnn_trainer<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'>=</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dnn_trainer<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#0000FF'>delete</font>; |
|
|
|
<!-- dnn_trainer(net_): explicit constructor that uses a default constructed solver.
     job_pipe is constructed with 0 and the net member is initialized from net_.
     One device_data entry is pushed onto devices, built from the id returned by
     dlib::cuda::get_device(), the net member and the default solver; init() then runs.
     net_ is taken by non const reference. -->
<font color='#0000FF'>explicit</font> <b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font>net_type<font color='#5555FF'>&</font> net_<font face='Lucida Console'>)</font> : job_pipe<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>, net<font face='Lucida Console'>(</font>net_<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
solver_type default_solver; |
|
devices.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'><</font>device_data<font color='#5555FF'>></font><font face='Lucida Console'>(</font>dlib::cuda::<font color='#BB00BB'>get_device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, net, default_solver<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>init</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<!-- dnn_trainer(net_, solver_): same setup as the single argument constructor, but the
     caller supplied solver_ is passed to device_data in place of a default constructed
     solver. job_pipe is constructed with 0, the net member is initialized from net_,
     one device_data entry is created for the id returned by dlib::cuda::get_device(),
     and init() is called last. -->
<b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font> |
|
net_type<font color='#5555FF'>&</font> net_, |
|
<font color='#0000FF'>const</font> solver_type<font color='#5555FF'>&</font> solver_ |
|
<font face='Lucida Console'>)</font> : job_pipe<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>, net<font face='Lucida Console'>(</font>net_<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
devices.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'><</font>device_data<font color='#5555FF'>></font><font face='Lucida Console'>(</font>dlib::cuda::<font color='#BB00BB'>get_device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, net, solver_<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>init</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<!-- dnn_trainer(net_, solver_, cuda_extra_devices, thread_pools_): constructor for
     training on several CUDA devices.
     Parameters:
       net_               : network, taken by non const reference.
       solver_            : solver passed to every device_data entry.
       cuda_extra_devices : additional device ids; duplicates and the id of the device
                            that is current on entry are dropped.
       thread_pools_      : optional shared pointer to a threads vector, empty by default;
                            stored in the thread_pools member.
     Steps: create the first device_data entry for the current device, collect the extra
     ids in a std::set (which removes duplicates), erase the id of the first entry, then
     for each remaining id check it lies in [0, total_devices), make it the current device
     and push a device_data entry built with clone_net(). Finally the current device is
     set back to the id of the first entry and init() runs.
     NOTE(review): the range check happens inside the loop, so an invalid id is only
     detected after earlier ids have already been switched to. If DLIB_CASSERT leaves the
     constructor on failure (its behaviour is defined outside this excerpt), the current
     device would not be restored; confirm whether that matters to callers. -->
<b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font> |
|
net_type<font color='#5555FF'>&</font> net_, |
|
<font color='#0000FF'>const</font> solver_type<font color='#5555FF'>&</font> solver_, |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font><font color='#5555FF'>></font><font color='#5555FF'>&</font> cuda_extra_devices, |
|
std::shared_ptr<font color='#5555FF'><</font>threads<font color='#5555FF'>></font> thread_pools_ <font color='#5555FF'>=</font> std::shared_ptr<font color='#5555FF'><</font>threads<font color='#5555FF'>></font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font face='Lucida Console'>)</font> : job_pipe<font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>, thread_pools<font face='Lucida Console'>(</font>thread_pools_<font face='Lucida Console'>)</font>, net<font face='Lucida Console'>(</font>net_<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
devices.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'><</font>device_data<font color='#5555FF'>></font><font face='Lucida Console'>(</font>dlib::cuda::<font color='#BB00BB'>get_device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, net, solver_<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> total_devices <font color='#5555FF'>=</font> dlib::cuda::<font color='#BB00BB'>get_num_devices</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Make device contexts for the extra device ids but be careful to avoid any |
|
</font> <font color='#009900'>// duplicate ids. |
|
</font> std::set<font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font><font color='#5555FF'>></font> <font color='#BB00BB'>temp</font><font face='Lucida Console'>(</font>cuda_extra_devices.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, cuda_extra_devices.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
temp.<font color='#BB00BB'>erase</font><font face='Lucida Console'>(</font>devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>></font>device_id<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font> id : temp<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#979000'>0</font> <font color='#5555FF'><</font><font color='#5555FF'>=</font> id <font color='#5555FF'>&</font><font color='#5555FF'>&</font> id <font color='#5555FF'><</font> total_devices, "<font color='#CC0000'>Invalid CUDA device id given to dnn_trainer.</font>"<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Switch to this device so that any tensor objects that get allocated when |
|
</font> <font color='#009900'>// we create the device context happen on this device. |
|
</font> dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>id<font face='Lucida Console'>)</font>; |
|
devices.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'><</font>device_data<font color='#5555FF'>></font><font face='Lucida Console'>(</font>id, net, solver_, <font color='#BB00BB'>clone_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#009900'>// Set the current device back to what it was before this constructor was |
|
</font> <font color='#009900'>// called. |
|
</font> dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>></font>device_id<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>init</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<!-- Destructor: disables job_pipe first, then calls stop() followed by wait().
     NOTE(review): stop() and wait() are presumably inherited from threaded_object,
     the private base of this class; their definitions are outside this excerpt. -->
~<b><a name='dnn_trainer'></a>dnn_trainer</b><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
job_pipe.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>stop</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>wait</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<!-- get_net(force_flush): returns a non const reference to the net member.
     Before returning it calls wait_for_thread_to_pause(), then sync_to_disk() with
     true when force_flush equals force_flush_to_disk::yes (the default) and false
     otherwise, then propagate_exception(). The helpers are defined outside this
     excerpt of the listing. -->
net_type<font color='#5555FF'>&</font> <b><a name='get_net'></a>get_net</b> <font face='Lucida Console'>(</font> |
|
force_flush_to_disk force_flush <font color='#5555FF'>=</font> force_flush_to_disk::yes |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font>force_flush <font color='#5555FF'>=</font><font color='#5555FF'>=</font> force_flush_to_disk::yes<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>propagate_exception</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font> net; |
|
<b>}</b> |
|
|
|
|
|
<!-- get_mini_batch_size(): const accessor that returns the mini_batch_size member. -->
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_mini_batch_size'></a>get_mini_batch_size</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> <b>{</b> <font color='#0000FF'>return</font> mini_batch_size; <b>}</b> |
|
|
|
<!-- set_mini_batch_size(batch_size): checks with DLIB_CASSERT that batch_size is
     greater than 0, then stores it in the mini_batch_size member. -->
<font color='#0000FF'><u>void</u></font> <b><a name='set_mini_batch_size'></a>set_mini_batch_size</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> batch_size |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>batch_size <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
mini_batch_size <font color='#5555FF'>=</font> batch_size; |
|
<b>}</b> |
|
|
|
<!-- get_max_num_epochs(): const accessor that returns the max_num_epochs member. -->
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_max_num_epochs'></a>get_max_num_epochs</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> <b>{</b> <font color='#0000FF'>return</font> max_num_epochs; <b>}</b> |
|
|
|
<!-- set_max_num_epochs(num): checks with DLIB_CASSERT that num is greater than 0,
     then stores it in the max_num_epochs member. -->
<font color='#0000FF'><u>void</u></font> <b><a name='set_max_num_epochs'></a>set_max_num_epochs</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> num |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>num <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
max_num_epochs <font color='#5555FF'>=</font> num; |
|
<b>}</b> |
|
|
|
<!-- be_verbose(): sets the verbose member to true. Within this excerpt the flag is
     read by train() to decide whether to print status lines to std::cout. -->
<font color='#0000FF'><u>void</u></font> <b><a name='be_verbose'></a>be_verbose</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
verbose <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |
|
<b>}</b> |
|
|
|
<!-- be_quiet(): sets the verbose member to false, the opposite of be_verbose(). -->
<font color='#0000FF'><u>void</u></font> <b><a name='be_quiet'></a>be_quiet</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
verbose <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
<b>}</b> |
|
|
|
|
|
<!-- get_solvers(): const member that returns a const reference to the solvers vector
     held by the first entry of devices. It first calls wait_for_thread_to_pause()
     and propagate_exception(), both defined outside this excerpt. Solvers held by
     any other device entry are not exposed by this accessor. -->
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>solver_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> <b><a name='get_solvers'></a>get_solvers</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>propagate_exception</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font> devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>></font>solvers; |
|
<b>}</b> |
|
|
|
<!-- train_one_step(data, labels): vector convenience overload.
     Checks with DLIB_CASSERT that data and labels have the same size, then forwards
     to the iterator overload with data.begin(), data.end() and labels.begin().
     The check that the range is not empty is made by that iterator overload. -->
<font color='#0000FF'><u>void</u></font> <b><a name='train_one_step'></a>train_one_step</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>input_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> data, |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>training_label_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> labels |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, data.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<!-- train_one_step(dbegin, dend, lbegin): iterator overload with labels.
     Checks with DLIB_CASSERT that the range from dbegin to dend is not empty, calls
     print_periodic_verbose_status() and sync_to_disk() with no arguments, then hands
     the ranges to send_job with false as the first argument and finally increments
     the train_one_step_calls counter.
     NOTE(review): the leading false presumably fills the test_only flag of the job;
     send_job is defined outside this excerpt, so confirm there. Only lbegin is given
     for the labels, so the label range is presumably expected to be at least as long
     as the data range. -->
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> |
|
<font color='#0000FF'>typename</font> data_iterator, |
|
<font color='#0000FF'>typename</font> label_iterator |
|
<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='train_one_step'></a>train_one_step</b> <font face='Lucida Console'>(</font> |
|
data_iterator dbegin, |
|
data_iterator dend, |
|
label_iterator lbegin |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>distance</font><font face='Lucida Console'>(</font>dbegin, dend<font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>print_periodic_verbose_status</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>false</font>, dbegin, dend, lbegin<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>train_one_step_calls; |
|
<b>}</b> |
|
|
|
<!-- train_one_step(data): vector convenience overload that takes no labels.
     Forwards data.begin() and data.end() to the two iterator overload, which makes
     the check that the range is not empty. -->
<font color='#0000FF'><u>void</u></font> <b><a name='train_one_step'></a>train_one_step</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>input_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> data |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, data.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<!-- train_one_step(dbegin, dend): iterator overload that takes no labels.
     Checks with DLIB_CASSERT that the range is not empty, calls
     print_periodic_verbose_status() and sync_to_disk(), passes the range to send_job
     with false as the first argument, then increments train_one_step_calls.
     NOTE(review): the leading false presumably marks the job as a training job and
     not a test job; confirm in send_job, which is outside this excerpt. -->
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> |
|
<font color='#0000FF'>typename</font> data_iterator |
|
<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='train_one_step'></a>train_one_step</b> <font face='Lucida Console'>(</font> |
|
data_iterator dbegin, |
|
data_iterator dend |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>distance</font><font face='Lucida Console'>(</font>dbegin, dend<font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>print_periodic_verbose_status</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>false</font>, dbegin, dend<font face='Lucida Console'>)</font>; |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>train_one_step_calls; |
|
<b>}</b> |
|
|
|
<!-- test_one_step(data, labels): vector convenience overload.
     Checks with DLIB_CASSERT that data and labels have the same size, then forwards
     to the iterator overload with data.begin(), data.end() and labels.begin().
     The check that the range is not empty is made by that iterator overload. -->
<font color='#0000FF'><u>void</u></font> <b><a name='test_one_step'></a>test_one_step</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>input_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> data, |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>training_label_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> labels |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>test_one_step</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, data.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<!-- test_one_step(dbegin, dend, lbegin): iterator overload with labels.
     Mirrors the labelled train_one_step iterator overload: checks that the data range
     is not empty, calls print_periodic_verbose_status() and sync_to_disk(), then
     hands the ranges to send_job. The differences are that send_job receives true as
     its first argument and that the test_one_step_calls counter is incremented.
     NOTE(review): the leading true presumably fills the test_only flag of the job;
     send_job is defined outside this excerpt, so confirm there. -->
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> |
|
<font color='#0000FF'>typename</font> data_iterator, |
|
<font color='#0000FF'>typename</font> label_iterator |
|
<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='test_one_step'></a>test_one_step</b> <font face='Lucida Console'>(</font> |
|
data_iterator dbegin, |
|
data_iterator dend, |
|
label_iterator lbegin |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>distance</font><font face='Lucida Console'>(</font>dbegin, dend<font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>print_periodic_verbose_status</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>true</font>, dbegin, dend, lbegin<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>test_one_step_calls; |
|
<b>}</b> |
|
|
|
<!-- test_one_step(data): vector convenience overload that takes no labels.
     Forwards data.begin() and data.end() to the two iterator overload, which makes
     the check that the range is not empty. -->
<font color='#0000FF'><u>void</u></font> <b><a name='test_one_step'></a>test_one_step</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>input_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> data |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>test_one_step</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, data.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<!-- test_one_step(dbegin, dend): iterator overload that takes no labels.
     Checks with DLIB_CASSERT that the range is not empty, calls
     print_periodic_verbose_status() and sync_to_disk(), passes the range to send_job
     with true as the first argument, then increments test_one_step_calls.
     NOTE(review): the leading true presumably marks the job as test only; confirm
     in send_job, which is outside this excerpt. -->
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> |
|
<font color='#0000FF'>typename</font> data_iterator |
|
<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='test_one_step'></a>test_one_step</b> <font face='Lucida Console'>(</font> |
|
data_iterator dbegin, |
|
data_iterator dend |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>distance</font><font face='Lucida Console'>(</font>dbegin, dend<font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>print_periodic_verbose_status</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>true</font>, dbegin, dend<font face='Lucida Console'>)</font>; |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>test_one_step_calls; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='train'></a>train</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>input_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> data, |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>training_label_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> labels |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> labels.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// The reason these two loops don't initialize their counter variables but |
|
</font> <font color='#009900'>// instead use class members is so we can include the state of the loops in the |
|
</font> <font color='#009900'>// stuff written by sync_to_disk() |
|
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font>; |
|
epoch_iteration <font color='#5555FF'><</font> max_num_epochs <font color='#5555FF'>&</font><font color='#5555FF'>&</font> learning_rate <font color='#5555FF'>></font><font color='#5555FF'>=</font> min_learning_rate; |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>epoch_iteration<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std::chrono; |
|
last_time <font color='#5555FF'>=</font> system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>clear_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font>; epoch_pos <font color='#5555FF'><</font> data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> learning_rate <font color='#5555FF'>></font><font color='#5555FF'>=</font> min_learning_rate; epoch_pos <font color='#5555FF'>+</font><font color='#5555FF'>=</font> mini_batch_size<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>auto</font> now_time <font color='#5555FF'>=</font> system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>now_time<font color='#5555FF'>-</font>last_time <font color='#5555FF'>></font> <font color='#BB00BB'>seconds</font><font face='Lucida Console'>(</font><font color='#979000'>20</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
last_time <font color='#5555FF'>=</font> now_time; |
|
<font color='#0000FF'>auto</font> iter <font color='#5555FF'>=</font> epoch_iteration <font color='#5555FF'>+</font> epoch_pos<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>epoch: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>iter<font face='Lucida Console'>)</font>,epoch_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>" |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>learning rate: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>learning_rate<font face='Lucida Console'>)</font>,lr_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>" |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>average loss: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>"; |
|
<font color='#BB00BB'>print_progress</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>false</font>, data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>epoch_pos, |
|
data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>epoch_pos<font color='#5555FF'>+</font>mini_batch_size,data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>, |
|
labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>epoch_pos<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
epoch_pos <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Capitalize the E in Epoch so it's easy to grep out the lines that |
|
</font> <font color='#009900'>// are for full epoch status statements. |
|
</font> std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Epoch: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>epoch_iteration<font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>,epoch_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>" |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>learning rate: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>learning_rate<font face='Lucida Console'>)</font>,lr_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>" |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>average loss: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>"; |
|
<font color='#BB00BB'>print_progress</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// if we modified the network at all then be sure to sync the final result. |
|
</font> <font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font color='#979000'>true</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
        // Trains the network on data using an unsupervised loss (no labels are passed).
        // A static_assert rejects this overload unless training_label_type is
        // no_label_type.  Epochs run while epoch_iteration is below max_num_epochs
        // and learning_rate has not dropped below min_learning_rate; within each
        // epoch one mini-batch of at most mini_batch_size samples is handed to
        // send_job() per iteration.  After the loops finish this calls
        // wait_for_thread_to_pause() and then sync_to_disk(true).
        // data must not be empty (checked with DLIB_CASSERT).
        <font color='#0000FF'><u>void</u></font> <b><a name='train'></a>train</b> <font face='Lucida Console'>(</font> |
|
            <font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>input_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> data |
|
        <font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
 |
|
            <font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> has_unsupervised_loss <font color='#5555FF'>=</font> std::is_same<font color='#5555FF'><</font>no_label_type, training_label_type<font color='#5555FF'>></font>::value;  |
|
            <font color='#BB00BB'>static_assert</font><font face='Lucida Console'>(</font>has_unsupervised_loss,  |
|
                "<font color='#CC0000'>You can only call this version of train() when using an unsupervised loss.</font>"<font face='Lucida Console'>)</font>; |
|
 |
|
            <font color='#009900'>// The reason these two loops don't initialize their counter variables but |
|
</font>            <font color='#009900'>// instead use class members is so we can include the state of the loops in the |
|
</font>            <font color='#009900'>// stuff written by sync_to_disk() |
|
</font>            <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font>;  |
|
                epoch_iteration <font color='#5555FF'><</font> max_num_epochs <font color='#5555FF'>&</font><font color='#5555FF'>&</font> learning_rate <font color='#5555FF'>></font><font color='#5555FF'>=</font> min_learning_rate;  |
|
                <font color='#5555FF'>+</font><font color='#5555FF'>+</font>epoch_iteration<font face='Lucida Console'>)</font> |
|
            <b>{</b> |
|
                <font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std::chrono; |
|
                last_time <font color='#5555FF'>=</font> system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
                <font color='#BB00BB'>clear_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
                <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font>; epoch_pos <font color='#5555FF'><</font> data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> learning_rate <font color='#5555FF'>></font><font color='#5555FF'>=</font> min_learning_rate; epoch_pos <font color='#5555FF'>+</font><font color='#5555FF'>=</font> mini_batch_size<font face='Lucida Console'>)</font> |
|
                <b>{</b> |
|
                    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font> |
|
                    <b>{</b> |
|
                        <font color='#0000FF'>auto</font> now_time <font color='#5555FF'>=</font> system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
                        // Print an intra-epoch status line only when more than 20 seconds
                        // have passed since last_time was last updated.
                        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>now_time<font color='#5555FF'>-</font>last_time <font color='#5555FF'>></font> <font color='#BB00BB'>seconds</font><font face='Lucida Console'>(</font><font color='#979000'>20</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
                        <b>{</b> |
|
                            last_time <font color='#5555FF'>=</font> now_time; |
|
                            // Fractional epoch count: completed epochs plus the share of data
                            // already consumed in the current one.
                            <font color='#0000FF'>auto</font> iter <font color='#5555FF'>=</font> epoch_iteration <font color='#5555FF'>+</font> epoch_pos<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font>data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
                            std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>epoch: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>iter<font face='Lucida Console'>)</font>,epoch_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>  </font>" |
|
                                      <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>learning rate: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>learning_rate<font face='Lucida Console'>)</font>,lr_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>  </font>" |
|
                                      <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>average loss: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>  </font>"; |
|
                            <font color='#BB00BB'>print_progress</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
                        <b>}</b> |
|
                    <b>}</b> |
|
 |
|
                    <font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
                    // The end iterator is clamped to data.size() so the last mini-batch of
                    // an epoch may be shorter than mini_batch_size.
                    <font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font><font color='#979000'>false</font>, data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>epoch_pos,  |
|
                             data.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>epoch_pos<font color='#5555FF'>+</font>mini_batch_size,data.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
                <b>}</b> |
|
                // Rewind so the next epoch starts again at the beginning of data.
                epoch_pos <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
 |
|
                <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font> |
|
                <b>{</b> |
|
                    <font color='#009900'>// Capitalize the E in Epoch so it's easy to grep out the lines that |
|
</font>                    <font color='#009900'>// are for full epoch status statements. |
|
</font>                    std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Epoch: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>epoch_iteration<font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>,epoch_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>  </font>" |
|
                              <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>learning rate: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>learning_rate<font face='Lucida Console'>)</font>,lr_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>  </font>" |
|
                              <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>average loss: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>  </font>"; |
|
                    <font color='#BB00BB'>print_progress</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
                <b>}</b> |
|
            <b>}</b> |
|
            <font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            <font color='#009900'>// if we modified the network at all then be sure to sync the final result. |
|
</font>            <font color='#BB00BB'>sync_to_disk</font><font face='Lucida Console'>(</font><font color='#979000'>true</font><font face='Lucida Console'>)</font>; |
|
        <b>}</b> |
|
|
|
        // Stores filename in sync_filename and time_between_syncs_ (default: 15
        // minutes) in time_between_syncs, and resets last_sync_time to the current
        // time.  If the file named by newest_syncfile() can be opened, the trainer
        // state is then deserialized from it into this object.
        <font color='#0000FF'><u>void</u></font> <b><a name='set_synchronization_file'></a>set_synchronization_file</b> <font face='Lucida Console'>(</font> |
|
            <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&</font> filename, |
|
            std::chrono::seconds time_between_syncs_ <font color='#5555FF'>=</font> std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>15</font><font face='Lucida Console'>)</font> |
|
        <font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            last_sync_time <font color='#5555FF'>=</font> std::chrono::system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            sync_filename <font color='#5555FF'>=</font> filename; |
|
            time_between_syncs <font color='#5555FF'>=</font> time_between_syncs_; |
|
 |
|
            <font color='#009900'>// check if the sync file already exists, if it does we should load it. |
|
</font>            std::ifstream <font color='#BB00BB'>fin</font><font face='Lucida Console'>(</font><font color='#BB00BB'>newest_syncfile</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, std::ios::binary<font face='Lucida Console'>)</font>; |
|
            // A stream that failed to open converts to false, so nothing is loaded then.
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>fin<font face='Lucida Console'>)</font> |
|
                <font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font>, fin<font face='Lucida Console'>)</font>; |
|
        <b>}</b> |
|
|
|
        // Returns a reference to sync_filename, the name stored by
        // set_synchronization_file().
        // NOTE(review): unlike the other getters in this class this one is not
        // declared const -- confirm whether that is intentional.
        <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&</font> <b><a name='get_synchronization_file'></a>get_synchronization_file</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> sync_filename; |
|
        <b>}</b> |
|
|
|
        // Returns rs.mean(), the mean of the training loss values accumulated in rs.
        // Calls wait_for_thread_to_pause() first, before rs is read.
        <font color='#0000FF'><u>double</u></font> <b><a name='get_average_loss'></a>get_average_loss</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            <font color='#0000FF'>return</font> rs.<font color='#BB00BB'>mean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
        <b>}</b> |
|
|
|
        // Returns rs_test.mean(), the mean of the test loss values accumulated in
        // rs_test.  Calls wait_for_thread_to_pause() first, before rs_test is read.
        <font color='#0000FF'><u>double</u></font> <b><a name='get_average_test_loss'></a>get_average_test_loss</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            <font color='#0000FF'>return</font> rs_test.<font color='#BB00BB'>mean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
        <b>}</b> |
|
|
|
        // Clears the training loss statistics held in rs, after first calling
        // wait_for_thread_to_pause().  rs_test is not touched here.
        <font color='#0000FF'><u>void</u></font> <b><a name='clear_average_loss'></a>clear_average_loss</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            <font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            rs.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
        <b>}</b> |
|
|
|
        // Sets learning_rate to lr, which must be greater than 0 (DLIB_CASSERT).
        // Calls wait_for_thread_to_pause() before touching any state.  When lr
        // differs from the current rate, both steps-without-progress counters are
        // zeroed and the three stored loss histories are cleared.  In every case
        // lr_schedule is emptied via set_size(0).
        <font color='#0000FF'><u>void</u></font> <b><a name='set_learning_rate'></a>set_learning_rate</b> <font face='Lucida Console'>(</font> |
|
            <font color='#0000FF'><u>double</u></font> lr |
|
        <font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>lr <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
            <font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            // Only an actual change of rate resets the progress tracking state.
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>learning_rate <font color='#5555FF'>!</font><font color='#5555FF'>=</font> lr<font face='Lucida Console'>)</font> |
|
            <b>{</b> |
|
                steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
                test_steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
                previous_loss_values.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
                test_previous_loss_values.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
                previous_loss_values_to_keep_until_disk_sync.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            <b>}</b> |
|
            learning_rate <font color='#5555FF'>=</font> lr; |
|
            lr_schedule.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
        <b>}</b> |
|
|
|
        // Returns the current value of learning_rate.
        <font color='#0000FF'><u>double</u></font> <b><a name='get_learning_rate'></a>get_learning_rate</b><font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> learning_rate; |
|
        <b>}</b> |
|
|
|
        // Sets min_learning_rate to lr, which must be greater than 0 (DLIB_CASSERT).
        // Calls wait_for_thread_to_pause() first and empties lr_schedule via
        // set_size(0).
        <font color='#0000FF'><u>void</u></font> <b><a name='set_min_learning_rate'></a>set_min_learning_rate</b> <font face='Lucida Console'>(</font> |
|
            <font color='#0000FF'><u>double</u></font> lr |
|
        <font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>lr <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
            <font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            lr_schedule.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
            min_learning_rate <font color='#5555FF'>=</font> lr; |
|
        <b>}</b> |
|
|
|
        // Returns the current value of min_learning_rate.
        <font color='#0000FF'><u>double</u></font> <b><a name='get_min_learning_rate'></a>get_min_learning_rate</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> min_learning_rate; |
|
        <b>}</b> |
|
|
|
        // Installs an explicit learning rate schedule.  schedule must be non-empty
        // and all of its elements must be greater than 0 (both checked with
        // DLIB_CASSERT).  The learning rate becomes schedule(0,0), the minimum
        // learning rate becomes min(schedule), and the shrink factor is set to 1.
        // The schedule is stored in lr_schedule as a column vector of doubles and
        // lr_schedule_pos is reset to 0.
        <font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> EXP<font color='#5555FF'>></font> |
|
        <font color='#0000FF'><u>void</u></font> <b><a name='set_learning_rate_schedule'></a>set_learning_rate_schedule</b> <font face='Lucida Console'>(</font> |
|
            <font color='#0000FF'>const</font> matrix_exp<font color='#5555FF'><</font>EXP<font color='#5555FF'>></font><font color='#5555FF'>&</font> schedule |
|
        <font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
            <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>schedule<font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
            // Each of the three setters below empties lr_schedule, so the schedule
            // itself has to be assigned only after they have all run.
            <font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font><font color='#BB00BB'>schedule</font><font face='Lucida Console'>(</font><font color='#979000'>0</font>,<font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
            <font color='#BB00BB'>set_min_learning_rate</font><font face='Lucida Console'>(</font><font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>schedule<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
            <font color='#BB00BB'>set_learning_rate_shrink_factor</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; |
|
            lr_schedule <font color='#5555FF'>=</font> matrix_cast<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font><font color='#BB00BB'>reshape_to_column_vector</font><font face='Lucida Console'>(</font>schedule<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
            lr_schedule_pos <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
        <b>}</b> |
|
|
|
        // Returns a reference to lr_schedule.  Several setters in this class empty
        // it with set_size(0), so the returned vector may have size 0.
        <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>></font><font color='#5555FF'>&</font> <b><a name='get_learning_rate_schedule'></a>get_learning_rate_schedule</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> lr_schedule; |
|
        <b>}</b> |
|
|
|
        // Sets iter_without_progress_thresh to thresh.  Calls
        // wait_for_thread_to_pause() first and empties lr_schedule via set_size(0).
        <font color='#0000FF'><u>void</u></font> <b><a name='set_iterations_without_progress_threshold'></a>set_iterations_without_progress_threshold</b> <font face='Lucida Console'>(</font> |
|
            <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> thresh  |
|
        <font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            <font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            lr_schedule.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
            iter_without_progress_thresh <font color='#5555FF'>=</font> thresh; |
|
        <b>}</b> |
|
|
|
        // Returns the current value of iter_without_progress_thresh.
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_iterations_without_progress_threshold'></a>get_iterations_without_progress_threshold</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> iter_without_progress_thresh; |
|
        <b>}</b> |
|
|
|
        // Returns the current value of steps_without_progress.
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_steps_without_progress'></a>get_steps_without_progress</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> steps_without_progress; |
|
        <b>}</b> |
|
|
|
        // Sets test_iter_without_progress_thresh to thresh.  Calls
        // wait_for_thread_to_pause() first and empties lr_schedule via set_size(0).
        <font color='#0000FF'><u>void</u></font> <b><a name='set_test_iterations_without_progress_threshold'></a>set_test_iterations_without_progress_threshold</b> <font face='Lucida Console'>(</font> |
|
            <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> thresh  |
|
        <font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            <font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            lr_schedule.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
            test_iter_without_progress_thresh <font color='#5555FF'>=</font> thresh; |
|
        <b>}</b> |
|
|
|
        // Returns the current value of test_iter_without_progress_thresh.
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_test_iterations_without_progress_threshold'></a>get_test_iterations_without_progress_threshold</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> test_iter_without_progress_thresh; |
|
        <b>}</b> |
|
|
|
        // Returns the current value of test_steps_without_progress.
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_test_steps_without_progress'></a>get_test_steps_without_progress</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> test_steps_without_progress; |
|
        <b>}</b> |
|
|
|
        // Sets learning_rate_shrink to shrink, which must lie in the half-open
        // range (0, 1] (DLIB_CASSERT).  Calls wait_for_thread_to_pause() first,
        // empties lr_schedule via set_size(0), and zeroes both
        // steps_without_progress and test_steps_without_progress.
        <font color='#0000FF'><u>void</u></font> <b><a name='set_learning_rate_shrink_factor'></a>set_learning_rate_shrink_factor</b> <font face='Lucida Console'>(</font> |
|
            <font color='#0000FF'><u>double</u></font> shrink |
|
        <font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font><font color='#979000'>0</font> <font color='#5555FF'><</font> shrink <font color='#5555FF'>&</font><font color='#5555FF'>&</font> shrink <font color='#5555FF'><</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>; |
|
            <font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
            lr_schedule.<font color='#BB00BB'>set_size</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>; |
|
            learning_rate_shrink <font color='#5555FF'>=</font> shrink; |
|
            steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
            test_steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
        <b>}</b> |
|
|
|
        // Returns the current value of learning_rate_shrink.
        <font color='#0000FF'><u>double</u></font> <b><a name='get_learning_rate_shrink_factor'></a>get_learning_rate_shrink_factor</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> learning_rate_shrink; |
|
        <b>}</b> |
|
|
|
        // Returns the current value of the train_one_step_calls counter.
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_train_one_step_calls'></a>get_train_one_step_calls</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> train_one_step_calls; |
|
        <b>}</b> |
|
|
|
        // Returns the current value of the test_one_step_calls counter.
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_test_one_step_calls'></a>get_test_one_step_calls</b> <font face='Lucida Console'>(</font> |
|
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
        <b>{</b> |
|
            <font color='#0000FF'>return</font> test_one_step_calls; |
|
        <b>}</b> |
|
|
|
<font color='#0000FF'>private</font>: |
|
|
|
        // Records one test loss value.  loss is always appended to
        // test_previous_loss_values, but it is added to the rs_test statistics
        // only when is_finite(loss) holds.  The history is then trimmed from the
        // front until it holds at most test_iter_without_progress_thresh entries.
        <font color='#0000FF'><u>void</u></font> <b><a name='record_test_loss'></a>record_test_loss</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font> loss<font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            test_previous_loss_values.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font>; |
|
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>is_finite</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
                rs_test.<font color='#BB00BB'>add</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font>; |
|
            <font color='#009900'>// discard really old loss values. |
|
</font>            <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>test_previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> test_iter_without_progress_thresh<font face='Lucida Console'>)</font> |
|
                test_previous_loss_values.<font color='#BB00BB'>pop_front</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
        <b>}</b> |
|
|
|
        // Records one training loss value.  Each call adds 200 to
        // gradient_check_budget, adds loss to the rs statistics, and appends it to
        // previous_loss_values, which is then trimmed from the front until it
        // holds at most iter_without_progress_thresh entries.  When sync_filename
        // is non-empty the loss is also appended to
        // previous_loss_values_to_keep_until_disk_sync.
        // NOTE(review): unlike record_test_loss(), loss is added to rs without an
        // is_finite() check -- confirm that this difference is intended.
        <font color='#0000FF'><u>void</u></font> <b><a name='record_loss'></a>record_loss</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font> loss<font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            <font color='#009900'>// This kind of budgeting causes our gradient checking to use a fixed amount of |
|
</font>            <font color='#009900'>// computational resources, regardless of the size of iter_without_progress_thresh. |
|
</font>            gradient_check_budget <font color='#5555FF'>+</font><font color='#5555FF'>=</font> <font color='#979000'>200</font>; |
|
 |
|
            rs.<font color='#BB00BB'>add</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font>; |
|
            previous_loss_values.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font>; |
|
            <font color='#009900'>// discard really old loss values. |
|
</font>            <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> iter_without_progress_thresh<font face='Lucida Console'>)</font> |
|
                previous_loss_values.<font color='#BB00BB'>pop_front</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
 |
|
            <font color='#009900'>// separately keep another loss history until disk sync |
|
</font>            <font color='#009900'>// (but only if disk sync is enabled) |
|
</font>            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>sync_filename.<font color='#BB00BB'>empty</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
                previous_loss_values_to_keep_until_disk_sync.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>loss<font face='Lucida Console'>)</font>; |
|
        <b>}</b> |
|
|
|
<font color='#009900'>// Labelled-data overload, selected by the type of the unnamed third argument.</font> |
|
<font color='#009900'>// Runs the job's slice for `device` through that device's network and returns the</font> |
|
<font color='#009900'>// loss: compute_loss() for a test_only job, compute_parameter_gradients() otherwise.</font> |
|
<font color='#009900'>// Returns 0 without touching the device when the job carries no data for it.</font> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> T<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>double</u></font> <b><a name='compute_parameter_gradients'></a>compute_parameter_gradients</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> device, job_t<font color='#5555FF'>&</font> next_job, <font color='#0000FF'>const</font> T<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Nothing was assigned to this device for the current job.</font> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>next_job.have_data[device]<font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font> <font color='#979000'>0</font>; |
|
|
|
<font color='#0000FF'>auto</font><font color='#5555FF'>&</font><font color='#5555FF'>&</font> target <font color='#5555FF'>=</font> <font color='#5555FF'>*</font>devices[device]; |
|
<font color='#009900'>// Select the device before running anything on its network.</font> |
|
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>target.device_id<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.test_only<font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font> target.net.<font color='#BB00BB'>compute_loss</font><font face='Lucida Console'>(</font>next_job.t[device], next_job.labels[device].<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>return</font> target.net.<font color='#BB00BB'>compute_parameter_gradients</font><font face='Lucida Console'>(</font>next_job.t[device], next_job.labels[device].<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Unlabelled-data overload, selected when the third argument is a no_label_type.</font> |
|
<font color='#009900'>// Same contract as the labelled overload, except that only the input tensor is</font> |
|
<font color='#009900'>// passed to the network: returns compute_loss() for a test_only job and</font> |
|
<font color='#009900'>// compute_parameter_gradients() otherwise, or 0 when the job carries no data for</font> |
|
<font color='#009900'>// `device`.</font> |
|
<font color='#0000FF'><u>double</u></font> <b><a name='compute_parameter_gradients'></a>compute_parameter_gradients</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> device, job_t<font color='#5555FF'>&</font> next_job, <font color='#0000FF'>const</font> no_label_type<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.have_data[device]<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>auto</font><font color='#5555FF'>&</font><font color='#5555FF'>&</font> dev <font color='#5555FF'>=</font> <font color='#5555FF'>*</font>devices[device]; |
|
<font color='#009900'>// Select the device before running anything on its network.</font> |
|
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>dev.device_id<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.test_only<font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font> dev.net.<font color='#BB00BB'>compute_loss</font><font face='Lucida Console'>(</font>next_job.t[device]<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>else</font> |
|
<font color='#0000FF'>return</font> dev.net.<font color='#BB00BB'>compute_parameter_gradients</font><font face='Lucida Console'>(</font>next_job.t[device]<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
<font color='#009900'>// Nothing was assigned to this device for the current job.</font> |
|
<font color='#0000FF'>return</font> <font color='#979000'>0</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Applies one solver step to the network that lives on the given device, using</font> |
|
<font color='#009900'>// that device's solvers and the trainer's current learning_rate.</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='update_parameters'></a>update_parameters</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> device<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>auto</font><font color='#5555FF'>&</font><font color='#5555FF'>&</font> target <font color='#5555FF'>=</font> <font color='#5555FF'>*</font>devices[device]; |
|
<font color='#009900'>// Select the device before touching its network.</font> |
|
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>target.device_id<font face='Lucida Console'>)</font>; |
|
target.net.<font color='#BB00BB'>update_parameters</font><font face='Lucida Console'>(</font><font color='#BB00BB'>make_sstack</font><font face='Lucida Console'>(</font>target.solvers<font face='Lucida Console'>)</font>, learning_rate<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Job-processing loop of the trainer.  It pulls jobs from job_pipe until dequeue()</font> |
|
<font color='#009900'>// returns false.  A test_only job only computes and records the test loss and may</font> |
|
<font color='#009900'>// shrink the learning rate.  Any other job computes gradients on every device,</font> |
|
<font color='#009900'>// averages them across devices when there is more than one, applies the solver</font> |
|
<font color='#009900'>// updates, and then adjusts the learning rate (progress check or lr_schedule).</font> |
|
<font color='#009900'>// The body is a function-try-block: any exception disables job_pipe and is stored</font> |
|
<font color='#009900'>// in eptr instead of propagating out of this function.</font> |
|
<font color='#009900'>// NOTE(review): presumably this runs on the trainer's own worker thread -- confirm</font> |
|
<font color='#009900'>// where it is started.</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='thread'></a>thread</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>try</font> |
|
<b>{</b> |
|
<font color='#009900'>// Tag object: its type alone selects which compute_parameter_gradients() overload</font> |
|
<font color='#009900'>// (labelled or no_label_type) the tasks below call.</font> |
|
training_label_type pick_which_run_update; |
|
job_t next_job; |
|
|
|
<font color='#009900'>// One future per device; each receives the loss computed by that device's task.</font> |
|
std::vector<font color='#5555FF'><</font>dlib::future<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font><font color='#5555FF'>></font> <font color='#BB00BB'>losses</font><font face='Lucida Console'>(</font>devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
std::vector<font color='#5555FF'><</font>tt::multi_device_tensor_averager<font color='#5555FF'>></font> averagers; |
|
<font color='#009900'>// An array of all the parameter tensors in the first network. We will |
|
</font> <font color='#009900'>// periodically copy these tensors to all the other devices to make sure the |
|
</font> <font color='#009900'>// different GPUs don't go out of sync. |
|
</font> std::vector<font color='#5555FF'><</font>tensor<font color='#5555FF'>*</font><font color='#5555FF'>></font> reference_params; |
|
<font color='#BB00BB'>visit_layer_parameters</font><font face='Lucida Console'>(</font>devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>></font>net, [<font color='#5555FF'>&</font>]<font face='Lucida Console'>(</font>tensor<font color='#5555FF'>&</font> t<font face='Lucida Console'>)</font> <b>{</b> reference_params.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#5555FF'>&</font>t<font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// If no external thread pools vector was passed, then create one that will |
|
</font> <font color='#009900'>// be automatically destructed as soon as the dnn_trainer object goes out of |
|
</font> <font color='#009900'>// scope. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>thread_pools<font face='Lucida Console'>)</font> |
|
thread_pools <font color='#5555FF'>=</font> std::make_shared<font color='#5555FF'><</font>threads<font color='#5555FF'>></font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>auto</font><font color='#5555FF'>&</font> tp <font color='#5555FF'>=</font> <font color='#5555FF'>*</font>thread_pools; |
|
|
|
<font color='#009900'>// We make separate thread pools with just one thread in them because we want |
|
</font> <font color='#009900'>// to make sure each device is always executed on the same thread. We care |
|
</font> <font color='#009900'>// about this because there are thread_local context variables for some cuda |
|
</font> <font color='#009900'>// components and they get allocated for each combination of thread and device. |
|
</font> <font color='#009900'>// So if we make sure the same device always uses the same thread this will |
|
</font> <font color='#009900'>// reduce the number of contexts we allocate from num_devices*num_devices to |
|
</font> <font color='#009900'>// just num_devices. |
|
</font> <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>tp.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
tp.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'><</font>thread_pool<font color='#5555FF'>></font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
|
|
<font color='#009900'>// Counts training (non-test) jobs; it drives the periodic parameter re-sync below.</font> |
|
main_iteration_counter <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'>while</font><font face='Lucida Console'>(</font>job_pipe.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>next_job<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.test_only<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// compute the testing loss |
|
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
tp[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font><font color='#BB00BB'>add_task_by_value</font><font face='Lucida Console'>(</font>[<font color='#5555FF'>&</font>,i]<font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&</font> loss<font face='Lucida Console'>)</font><b>{</b> loss <font color='#5555FF'>=</font> <font color='#BB00BB'>compute_parameter_gradients</font><font face='Lucida Console'>(</font>i, next_job, pick_which_run_update<font face='Lucida Console'>)</font>; <b>}</b>, losses[i]<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// aggregate loss values from all the network computations. |
|
</font> <font color='#0000FF'><u>double</u></font> theloss <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&</font><font color='#5555FF'>&</font> loss : losses<font face='Lucida Console'>)</font> |
|
theloss <font color='#5555FF'>+</font><font color='#5555FF'>=</font> loss.<font color='#BB00BB'>get</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>record_test_loss</font><font face='Lucida Console'>(</font>theloss<font color='#5555FF'>/</font>losses.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Check if we should shrink the learning rate based on how the test |
|
</font> <font color='#009900'>// error has been doing lately. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>learning_rate_shrink <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#BB00BB'>count_steps_without_decrease</font><font face='Lucida Console'>(</font>test_previous_loss_values<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>test_steps_without_progress <font color='#5555FF'>></font><font color='#5555FF'>=</font> test_iter_without_progress_thresh<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#BB00BB'>count_steps_without_decrease_robust</font><font face='Lucida Console'>(</font>test_previous_loss_values<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>test_steps_without_progress <font color='#5555FF'>></font><font color='#5555FF'>=</font> test_iter_without_progress_thresh<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// optimization has flattened out, so drop the learning rate. |
|
</font> learning_rate <font color='#5555FF'>=</font> learning_rate_shrink<font color='#5555FF'>*</font>learning_rate; |
|
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
<font color='#009900'>// Empty out some of the previous loss values so that test_steps_without_progress |
|
</font> <font color='#009900'>// will decrease below test_iter_without_progress_thresh. |
|
</font> <font color='#BB00BB'>drop_some_test_previous_loss_values</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<b>}</b> |
|
<font color='#009900'>// A test job never updates the network, so skip the training steps below.</font> |
|
<font color='#0000FF'>continue</font>; |
|
<b>}</b> |
|
|
|
updated_net_since_last_sync <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>main_iteration_counter; |
|
<font color='#009900'>// Call compute_parameter_gradients() and update_parameters() but pick the |
|
</font> <font color='#009900'>// right version for unsupervised or supervised training based on the type |
|
</font> <font color='#009900'>// of training_label_type. |
|
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
tp[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font><font color='#BB00BB'>add_task_by_value</font><font face='Lucida Console'>(</font>[<font color='#5555FF'>&</font>,i]<font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&</font> loss<font face='Lucida Console'>)</font><b>{</b> loss <font color='#5555FF'>=</font> <font color='#BB00BB'>compute_parameter_gradients</font><font face='Lucida Console'>(</font>i, next_job, pick_which_run_update<font face='Lucida Console'>)</font>; <b>}</b>, losses[i]<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// aggregate loss values from all the network computations. |
|
</font> <font color='#0000FF'><u>double</u></font> theloss <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&</font><font color='#5555FF'>&</font> loss : losses<font face='Lucida Console'>)</font> |
|
theloss <font color='#5555FF'>+</font><font color='#5555FF'>=</font> loss.<font color='#BB00BB'>get</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>record_loss</font><font face='Lucida Console'>(</font>theloss<font color='#5555FF'>/</font>losses.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Now, if there is more than one active device we need to synchronize the |
|
</font> <font color='#009900'>// gradient updates between devices. So we do that now. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>1</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// if this is the first iteration then we need to setup the averagers. |
|
</font> <font color='#009900'>// We can't do this outside the loop because the tensors that get |
|
</font> <font color='#009900'>// averaged need to be allocated to their devices before we call set() |
|
</font> <font color='#009900'>// so that the averagers can determine how best to average them. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>averagers.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> sync_file_reloaded<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
averagers <font color='#5555FF'>=</font> std::vector<font color='#5555FF'><</font>tt::multi_device_tensor_averager<font color='#5555FF'>></font><font face='Lucida Console'>(</font>net_type::num_computational_layers<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// setup the averagers to point to the tensors in the networks. |
|
</font> std::vector<font color='#5555FF'><</font>std::vector<font color='#5555FF'><</font>tensor<font color='#5555FF'>*</font><font color='#5555FF'>></font><font color='#5555FF'>></font> <font color='#BB00BB'>all_tensors</font><font face='Lucida Console'>(</font>devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> all_tensors.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
all_tensors[i].<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>net_type::num_computational_layers<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>visit_layer_parameter_gradients</font><font face='Lucida Console'>(</font>devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font>net, [<font color='#5555FF'>&</font>]<font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> j, tensor<font color='#5555FF'>&</font> t<font face='Lucida Console'>)</font><b>{</b> |
|
all_tensors[i][j] <font color='#5555FF'>=</font> <font color='#5555FF'>&</font>t; |
|
<b>}</b><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#009900'>// Now set each averager to average the tensors at the same layer in each |
|
</font> <font color='#009900'>// network. |
|
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> net_type::num_computational_layers; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
std::vector<font color='#5555FF'><</font>tensor<font color='#5555FF'>*</font><font color='#5555FF'>></font> <font color='#BB00BB'>temp</font><font face='Lucida Console'>(</font>all_tensors.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> j <font color='#5555FF'>=</font> <font color='#979000'>0</font>; j <font color='#5555FF'><</font> all_tensors.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>j<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
temp[j] <font color='#5555FF'>=</font> all_tensors[j][i]; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>temp[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>></font><font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> temp[j]<font color='#5555FF'>-</font><font color='#5555FF'>></font><font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
"<font color='#CC0000'>Make sure you don't modify the network structure </font>" |
|
"<font color='#CC0000'>or number of parameters after constructing the trainer.</font>"<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#009900'>// ignore layers that don't have parameters |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>temp[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>></font><font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
averagers[i].<font color='#BB00BB'>set</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
sync_file_reloaded <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
<b>}</b> |
|
|
|
|
|
<font color='#009900'>// NOTE(review): presumably this waits for each device to finish its pending work so</font> |
|
<font color='#009900'>// the gradients are complete before they are averaged -- confirm against</font> |
|
<font color='#009900'>// cuda::device_synchronize().</font> |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&</font><font color='#5555FF'>&</font> d : devices<font face='Lucida Console'>)</font> |
|
cuda::<font color='#BB00BB'>device_synchronize</font><font face='Lucida Console'>(</font>d<font color='#5555FF'>-</font><font color='#5555FF'>></font>device_id<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&</font><font color='#5555FF'>&</font> avg : averagers<font face='Lucida Console'>)</font> |
|
avg.<font color='#BB00BB'>average</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
|
|
<font color='#009900'>// Now apply all the updates to each device. |
|
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
tp[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font><font color='#BB00BB'>add_task_by_value</font><font face='Lucida Console'>(</font>[<font color='#5555FF'>&</font>,i]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>next_job.have_data[i]<font face='Lucida Console'>)</font> <font color='#BB00BB'>update_parameters</font><font face='Lucida Console'>(</font>i<font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// and wait for the updates to all happen. |
|
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
tp[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font><font color='#BB00BB'>wait_for_all_tasks</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
|
|
<font color='#009900'>// Every now and then force all the parameters to be the same just to make |
|
</font> <font color='#009900'>// sure they aren't drifting apart due to any non-deterministic behavior on |
|
</font> <font color='#009900'>// the GPU. It's also important to do this on the first iteration because |
|
</font> <font color='#009900'>// the different networks may be initialized differently when tensor data |
|
</font> <font color='#009900'>// is first passed through them. So this code block deals with these |
|
</font> <font color='#009900'>// issues. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>1</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> main_iteration_counter<font color='#5555FF'>%</font><font color='#979000'>2000</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>1</font>; i <font color='#5555FF'><</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>visit_layer_parameters</font><font face='Lucida Console'>(</font>devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font>net, [<font color='#5555FF'>&</font>]<font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> j, tensor<font color='#5555FF'>&</font> t<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>memcpy</font><font face='Lucida Console'>(</font>t, <font color='#5555FF'>*</font>reference_params[j]<font face='Lucida Console'>)</font>; |
|
<b>}</b><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#009900'>// If we have been running for a while then check if the loss is still |
|
</font> <font color='#009900'>// dropping. If it isn't then we will reduce the learning rate. Note that we |
|
</font> <font color='#009900'>// have a "budget" that prevents us from calling |
|
</font> <font color='#009900'>// count_steps_without_decrease() every iteration. We do this because |
|
</font> <font color='#009900'>// it can be expensive to compute when previous_loss_values is large. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>gradient_check_budget <font color='#5555FF'>></font> iter_without_progress_thresh <font color='#5555FF'>&</font><font color='#5555FF'>&</font> learning_rate_shrink <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
gradient_check_budget <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
steps_without_progress <font color='#5555FF'>=</font> <font color='#BB00BB'>count_steps_without_decrease</font><font face='Lucida Console'>(</font>previous_loss_values<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>steps_without_progress <font color='#5555FF'>></font><font color='#5555FF'>=</font> iter_without_progress_thresh<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Double check that we aren't seeing decrease. This second check |
|
</font> <font color='#009900'>// discards the top 10% largest values and checks again. We do |
|
</font> <font color='#009900'>// this because sometimes a mini-batch might be bad and cause the |
|
</font> <font color='#009900'>// loss to suddenly jump up, making count_steps_without_decrease() |
|
</font> <font color='#009900'>// return a large number. But if we discard the top 10% of the |
|
</font> <font color='#009900'>// values in previous_loss_values then we are robust to that kind |
|
</font> <font color='#009900'>// of noise. Another way of looking at it, if the reason |
|
</font> <font color='#009900'>// count_steps_without_decrease() returns a large value is only |
|
</font> <font color='#009900'>// because the most recent loss values have suddenly been large, |
|
</font> <font color='#009900'>// then we shouldn't stop or lower the learning rate. We should |
|
</font> <font color='#009900'>// keep going until whatever disturbance we hit is damped down. |
|
</font> steps_without_progress <font color='#5555FF'>=</font> <font color='#BB00BB'>count_steps_without_decrease_robust</font><font face='Lucida Console'>(</font>previous_loss_values<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>steps_without_progress <font color='#5555FF'>></font><font color='#5555FF'>=</font> iter_without_progress_thresh<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// optimization has flattened out, so drop the learning rate. |
|
</font> learning_rate <font color='#5555FF'>=</font> learning_rate_shrink<font color='#5555FF'>*</font>learning_rate; |
|
steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
<font color='#009900'>// Empty out some of the previous loss values so that steps_without_progress |
|
</font> <font color='#009900'>// will decrease below iter_without_progress_thresh. |
|
</font> <font color='#BB00BB'>drop_some_previous_loss_values</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>lr_schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#009900'>// or use the learning rate schedule if we have one. |
|
</font> <b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>lr_schedule_pos <font color='#5555FF'><</font> lr_schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
learning_rate <font color='#5555FF'>=</font> <font color='#BB00BB'>lr_schedule</font><font face='Lucida Console'>(</font>lr_schedule_pos<font color='#5555FF'>+</font><font color='#5555FF'>+</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>else</font> |
|
<font color='#009900'>// Schedule exhausted: use 0.99 times its last entry.  NOTE(review): presumably a</font> |
|
<font color='#009900'>// marker that lets other code detect the end of the schedule -- confirm.</font> |
|
learning_rate <font color='#5555FF'>=</font> <font color='#BB00BB'>lr_schedule</font><font face='Lucida Console'>(</font>lr_schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font><font color='#979000'>0.99</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<b>}</b> |
|
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>...<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// If an exception happens then permanently disable the trainer object. |
|
</font> job_pipe.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Store the in-flight exception in eptr while holding eptr_mutex.</font> |
|
std::lock_guard<font color='#5555FF'><</font>std::mutex<font color='#5555FF'>></font> <font color='#BB00BB'>lock</font><font face='Lucida Console'>(</font>eptr_mutex<font face='Lucida Console'>)</font>; |
|
eptr <font color='#5555FF'>=</font> std::<font color='#BB00BB'>current_exception</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Blocks the caller until job_pipe reports one blocked dequeue; that single</font> |
|
<font color='#009900'>// call is the whole body.  NOTE(review): presumably the blocked dequeue is the</font> |
|
<font color='#009900'>// worker thread waiting for its next job, i.e. it is idle -- confirm in thread().</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='wait_for_thread_to_pause'></a>wait_for_thread_to_pause</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
<b>{</b> |
|
job_pipe.<font color='#BB00BB'>wait_for_num_blocked_dequeues</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Constant widths shared by all instances.  NOTE(review): the code that reads</font> |
|
<font color='#009900'>// them is outside this excerpt; presumably they pad the fields of the printed</font> |
|
<font color='#009900'>// training status lines -- confirm against the printing code.</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>long</u></font> string_pad <font color='#5555FF'>=</font> <font color='#979000'>11</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>long</u></font> epoch_string_pad <font color='#5555FF'>=</font> <font color='#979000'>4</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>long</u></font> lr_string_pad <font color='#5555FF'>=</font> <font color='#979000'>4</font>; |
|
|
|
<font color='#009900'>// Puts the trainer into its default configuration: assigns a default to each</font> |
|
<font color='#009900'>// setting, counter and flag listed below, then calls start().</font> |
|
<font color='#009900'>// NOTE(review): start() is defined outside this excerpt; presumably it launches</font> |
|
<font color='#009900'>// the worker thread -- confirm.</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='init'></a>init</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
max_num_epochs <font color='#5555FF'>=</font> <font color='#979000'>10000</font>; |
|
mini_batch_size <font color='#5555FF'>=</font> <font color='#979000'>128</font>; |
|
verbose <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
learning_rate <font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>2</font>; |
|
min_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>5</font>; |
|
iter_without_progress_thresh <font color='#5555FF'>=</font> <font color='#979000'>2000</font>; |
|
steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
test_iter_without_progress_thresh <font color='#5555FF'>=</font> <font color='#979000'>500</font>; |
|
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
learning_rate_shrink <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>; |
|
epoch_iteration <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
epoch_pos <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
train_one_step_calls <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
test_one_step_calls <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
gradient_check_budget <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
lr_schedule_pos <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
main_iteration_counter <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
main_iteration_counter_at_last_disk_sync <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#009900'>// prob_loss_increasing_thresh starts at its default value; the max value is</font> |
|
<font color='#009900'>// the ceiling that sync_to_disk() compares it against before it also shrinks</font> |
|
<font color='#009900'>// the learning rate.</font> |
|
prob_loss_increasing_thresh_default_value <font color='#5555FF'>=</font> <font color='#979000'>0.99</font>; |
|
prob_loss_increasing_thresh_max_value <font color='#5555FF'>=</font> <font color='#979000'>0.99999</font>; |
|
prob_loss_increasing_thresh <font color='#5555FF'>=</font> prob_loss_increasing_thresh_default_value; |
|
updated_net_since_last_sync <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
sync_file_reloaded <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
previous_loss_values_dump_amount <font color='#5555FF'>=</font> <font color='#979000'>400</font>; |
|
test_previous_loss_values_dump_amount <font color='#5555FF'>=</font> <font color='#979000'>100</font>; |
|
|
|
<font color='#009900'>// NOTE(review): the meaning of the 200 argument is defined by</font> |
|
<font color='#009900'>// running_stats_decayed, which is not shown here -- confirm before changing it.</font> |
|
rs_test <font color='#5555FF'>=</font> running_stats_decayed<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>start</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// serialize and deserialize are private because we hold net by reference so |
|
</font> <font color='#009900'>// allowing someone to serialize this training object is weird and will likely |
|
</font> <font color='#009900'>// result in user errors. However, we use these functions as part of the automatic |
|
</font> <font color='#009900'>// sync code in this object. |
|
</font> <font color='#009900'>// |
|
</font> <font color='#009900'>// serialize() writes the state of item to out as a version tag, the layer |
|
</font> <font color='#009900'>// count, and then each member in a fixed order.  deserialize() reads the |
|
</font> <font color='#009900'>// fields back in that same order, so the two lists must be kept in step. |
|
</font> <font color='#0000FF'>friend</font> <font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dnn_trainer<font color='#5555FF'>&</font> item, std::ostream<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Nothing is written until wait_for_thread_to_pause() returns.</font> |
|
item.<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Format version tag; deserialize() rejects any value other than 13.</font> |
|
<font color='#0000FF'><u>int</u></font> version <font color='#5555FF'>=</font> <font color='#979000'>13</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>version, out<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// The layer count is stored so deserialize() can detect a sync file that was</font> |
|
<font color='#009900'>// written for a network with a different number of layers.</font> |
|
<font color='#0000FF'><u>size_t</u></font> nl <font color='#5555FF'>=</font> dnn_trainer::num_layers; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>nl, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.rs, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.rs_test, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.previous_loss_values, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.max_num_epochs, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.mini_batch_size, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.verbose, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.net, out<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Only the solvers of devices[0] are written; deserialize() copies them to the</font> |
|
<font color='#009900'>// other entries of devices after loading.</font> |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>></font>solvers, out<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Members read through .load() are written as plain values; deserialize()</font> |
|
<font color='#009900'>// reads each of them into a temporary and assigns it back.</font> |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.learning_rate.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.min_learning_rate, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.iter_without_progress_thresh.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.steps_without_progress.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.learning_rate_shrink.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.epoch_iteration, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.epoch_pos, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.train_one_step_calls, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.test_one_step_calls, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.lr_schedule, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.lr_schedule_pos, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.test_iter_without_progress_thresh.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.test_steps_without_progress.<font color='#BB00BB'>load</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.test_previous_loss_values, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.previous_loss_values_dump_amount, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.test_previous_loss_values_dump_amount, out<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.previous_loss_values_to_keep_until_disk_sync, out<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#009900'>// Restores the state of item from in, reading the fields in the order that</font> |
|
<font color='#009900'>// serialize() writes them.  Throws serialization_error when the stored version</font> |
|
<font color='#009900'>// is not 13 or the stored layer count differs from dnn_trainer::num_layers;</font> |
|
<font color='#009900'>// both checks run before any member of item is modified.</font> |
|
<font color='#0000FF'>friend</font> <font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>dnn_trainer<font color='#5555FF'>&</font> item, std::istream<font color='#5555FF'>&</font> in<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Nothing is read until wait_for_thread_to_pause() returns.</font> |
|
item.<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'><u>int</u></font> version <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>version, in<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>version <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>13</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>throw</font> <font color='#BB00BB'>serialization_error</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>Unexpected version found while deserializing dlib::dnn_trainer.</font>"<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'><u>size_t</u></font> num_layers <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>num_layers, in<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>num_layers <font color='#5555FF'>!</font><font color='#5555FF'>=</font> dnn_trainer::num_layers<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
std::ostringstream sout; |
|
sout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Error deserializing dlib::dnn_trainer. The saved sync file is for a network with </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> std::endl; |
|
sout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>a different number of layers. We expected the number of layers to be </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> dnn_trainer::num_layers <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> but</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> std::endl; |
|
sout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>instead the file contains </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> num_layers <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> layers.</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> std::endl; |
|
<font color='#0000FF'>throw</font> <font color='#BB00BB'>serialization_error</font><font face='Lucida Console'>(</font>sout.<font color='#BB00BB'>str</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Scratch values for the members that serialize() wrote through .load(); each</font> |
|
<font color='#009900'>// one is read into a temporary and then assigned to the member.</font> |
|
<font color='#0000FF'><u>double</u></font> dtemp; <font color='#0000FF'><u>long</u></font> ltemp; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.rs, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.rs_test, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.previous_loss_values, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.max_num_epochs, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.mini_batch_size, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.verbose, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.net, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>></font>solvers, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>dtemp, in<font face='Lucida Console'>)</font>; item.learning_rate <font color='#5555FF'>=</font> dtemp; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.min_learning_rate, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>ltemp, in<font face='Lucida Console'>)</font>; item.iter_without_progress_thresh <font color='#5555FF'>=</font> ltemp; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>ltemp, in<font face='Lucida Console'>)</font>; item.steps_without_progress <font color='#5555FF'>=</font> ltemp; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>dtemp, in<font face='Lucida Console'>)</font>; item.learning_rate_shrink <font color='#5555FF'>=</font> dtemp; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.epoch_iteration, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.epoch_pos, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.train_one_step_calls, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.test_one_step_calls, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.lr_schedule, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.lr_schedule_pos, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>ltemp, in<font face='Lucida Console'>)</font>; item.test_iter_without_progress_thresh <font color='#5555FF'>=</font> ltemp; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>ltemp, in<font face='Lucida Console'>)</font>; item.test_steps_without_progress <font color='#5555FF'>=</font> ltemp; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.test_previous_loss_values, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.previous_loss_values_dump_amount, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.test_previous_loss_values_dump_amount, in<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.previous_loss_values_to_keep_until_disk_sync, in<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Only devices[0] received solvers from the stream.  When devices has more</font> |
|
<font color='#009900'>// entries, copy the solvers and net of devices[0] into each of the others.</font> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>item.devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>1</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Remember the active device so it can be restored after the loop.</font> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> prev_dev <font color='#5555FF'>=</font> dlib::cuda::<font color='#BB00BB'>get_device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// initialize all the other device networks and solver objects |
|
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>1</font>; i <font color='#5555FF'><</font> item.devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Switch to this device so that any tensor objects that get allocated when |
|
</font> <font color='#009900'>// we copy this stuff happen on this device. |
|
</font> dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>item.devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font>device_id<font face='Lucida Console'>)</font>; |
|
item.devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font>solvers <font color='#5555FF'>=</font> item.devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>></font>solvers; |
|
item.devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font>net <font color='#5555FF'>=</font> item.devices[<font color='#979000'>0</font>]<font color='#5555FF'>-</font><font color='#5555FF'>></font>net; |
|
<b>}</b> |
|
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>prev_dev<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Empty out some of the previous loss values so that steps_without_progress will decrease below iter_without_progress_thresh. |
|
</font> <font color='#009900'>// Pops at most previous_loss_values_dump_amount + iter_without_progress_thresh / 10 |
|
</font> <font color='#009900'>// entries from the front of previous_loss_values and stops early once it is empty. |
|
</font> <font color='#0000FF'><u>void</u></font> <b><a name='drop_some_previous_loss_values'></a>drop_some_previous_loss_values</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> cnt <font color='#5555FF'>=</font> <font color='#979000'>0</font>; cnt <font color='#5555FF'><</font> previous_loss_values_dump_amount <font color='#5555FF'>+</font> iter_without_progress_thresh <font color='#5555FF'>/</font> <font color='#979000'>10</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>cnt<font face='Lucida Console'>)</font> |
|
previous_loss_values.<font color='#BB00BB'>pop_front</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Empty out some of the previous test loss values so that test_steps_without_progress will decrease below test_iter_without_progress_thresh. |
|
</font> <font color='#009900'>// Pops at most test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10 |
|
</font> <font color='#009900'>// entries from the front of test_previous_loss_values and stops early once it is empty. |
|
</font> <font color='#0000FF'><u>void</u></font> <b><a name='drop_some_test_previous_loss_values'></a>drop_some_test_previous_loss_values</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> cnt <font color='#5555FF'>=</font> <font color='#979000'>0</font>; cnt <font color='#5555FF'><</font> test_previous_loss_values_dump_amount <font color='#5555FF'>+</font> test_iter_without_progress_thresh <font color='#5555FF'>/</font> <font color='#979000'>10</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> test_previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>cnt<font face='Lucida Console'>)</font> |
|
test_previous_loss_values.<font color='#BB00BB'>pop_front</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Saves the trainer state to a sync file, or reloads the newest sync file when</font> |
|
<font color='#009900'>// loss_increased_since_last_disk_sync() returns true.  Returns without doing</font> |
|
<font color='#009900'>// anything when the network has not been updated since the last sync or when</font> |
|
<font color='#009900'>// sync_filename is empty.</font> |
|
<font color='#009900'>//   do_it_now: when true the time_between_syncs check is bypassed; otherwise</font> |
|
<font color='#009900'>//              work happens only if more than time_between_syncs has elapsed</font> |
|
<font color='#009900'>//              since last_sync_time.</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='sync_to_disk'></a>sync_to_disk</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>bool</u></font> do_it_now <font color='#5555FF'>=</font> <font color='#979000'>false</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// don't sync anything if we haven't updated the network since the last sync |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>updated_net_since_last_sync<font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#009900'>// If the sync file isn't set then don't do anything. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>sync_filename.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font>; |
|
|
|
<font color='#009900'>// Only sync if it has been long enough since the last sync or we are being |
|
</font> <font color='#009900'>// explicitly forced to do it. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>std::chrono::system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>-</font> last_sync_time <font color='#5555FF'>></font> time_between_syncs <font color='#5555FF'>|</font><font color='#5555FF'>|</font> |
|
do_it_now<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>wait_for_thread_to_pause</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// compact network before saving to disk. |
|
</font> <font color='#0000FF'>this</font><font color='#5555FF'>-</font><font color='#5555FF'>></font>net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// if the loss has actually been going up since the last time we saved our |
|
</font> <font color='#009900'>// state to disk then something has probably gone wrong in the |
|
</font> <font color='#009900'>// optimization. So in this case we do the opposite and recall the |
|
</font> <font color='#009900'>// previously saved state in the hopes that the problem won't reoccur. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>loss_increased_since_last_disk_sync</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// loss_increased_since_last_disk_sync() returns false when the newest sync</font> |
|
<font color='#009900'>// file cannot be opened, so fin is expected to be readable in this branch.</font> |
|
std::ifstream <font color='#BB00BB'>fin</font><font face='Lucida Console'>(</font><font color='#BB00BB'>newest_syncfile</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, std::ios::binary<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font><font color='#5555FF'>*</font><font color='#0000FF'>this</font>, fin<font face='Lucida Console'>)</font>; |
|
sync_file_reloaded <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font> |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Loss has been increasing, reloading saved state from </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>newest_syncfile</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> std::endl; |
|
|
|
<font color='#009900'>// Are we repeatedly hitting our head against the wall? If so, then we |
|
</font> <font color='#009900'>// might be better off giving up at this learning rate, and trying a |
|
</font> <font color='#009900'>// lower one instead. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>prob_loss_increasing_thresh <font color='#5555FF'>></font><font color='#5555FF'>=</font> prob_loss_increasing_thresh_max_value<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font> |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>(and while at it, also shrinking the learning rate)</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> std::endl; |
|
|
|
learning_rate <font color='#5555FF'>=</font> learning_rate_shrink <font color='#5555FF'>*</font> learning_rate; |
|
steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
test_steps_without_progress <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
<font color='#BB00BB'>drop_some_previous_loss_values</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>drop_some_test_previous_loss_values</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
|
|
<font color='#009900'>// The state is written to the file returned by oldest_syncfile().</font> |
|
<font color='#009900'>// NOTE(review): presumably so the newest snapshot is never the one being</font> |
|
<font color='#009900'>// overwritten -- confirm against select_oldest_file.</font> |
|
<font color='#0000FF'>const</font> std::string filename <font color='#5555FF'>=</font> <font color='#BB00BB'>oldest_syncfile</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>filename<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#5555FF'>*</font><font color='#0000FF'>this</font>; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font> |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Saved state to </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> filename <font color='#5555FF'><</font><font color='#5555FF'><</font> std::endl; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Bookkeeping that runs after either branch: stamp the sync time, record the</font> |
|
<font color='#009900'>// iteration counter at this sync, and clear the updated-since-sync flag.</font> |
|
last_sync_time <font color='#5555FF'>=</font> std::chrono::system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
main_iteration_counter_at_last_disk_sync <font color='#5555FF'>=</font> main_iteration_counter; |
|
updated_net_since_last_sync <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Returns the result of select_newest_file() for the two candidate sync files:</font> |
|
<font color='#009900'>// sync_filename and sync_filename with "_" appended.</font> |
|
std::string <b><a name='newest_syncfile'></a>newest_syncfile</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>return</font> <font color='#BB00BB'>select_newest_file</font><font face='Lucida Console'>(</font>sync_filename, sync_filename <font color='#5555FF'>+</font> "<font color='#CC0000'>_</font>"<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Returns the result of select_oldest_file() for the two candidate sync files:</font> |
|
<font color='#009900'>// sync_filename and sync_filename with "_" appended.</font> |
|
std::string <b><a name='oldest_syncfile'></a>oldest_syncfile</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>return</font> <font color='#BB00BB'>select_oldest_file</font><font face='Lucida Console'>(</font>sync_filename, sync_filename <font color='#5555FF'>+</font> "<font color='#CC0000'>_</font>"<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>bool</u></font> <b><a name='loss_increased_since_last_disk_sync'></a>loss_increased_since_last_disk_sync</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'><u>size_t</u></font> gradient_updates_since_last_sync <font color='#5555FF'>=</font> main_iteration_counter <font color='#5555FF'>-</font> main_iteration_counter_at_last_disk_sync; |
|
|
|
<font color='#009900'>// if we haven't synced anything to disk yet then return false. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>std::<font color='#BB00BB'>ifstream</font><font face='Lucida Console'>(</font><font color='#BB00BB'>newest_syncfile</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, std::ios::binary<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font> <font color='#979000'>false</font>; |
|
|
|
<font color='#009900'>// Now look at the data since a little before the last disk sync. We will |
|
</font> <font color='#009900'>// check if the loss is getting better or worse. |
|
</font> <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>previous_loss_values_to_keep_until_disk_sync.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>2</font> <font color='#5555FF'>*</font> gradient_updates_since_last_sync<font face='Lucida Console'>)</font> |
|
previous_loss_values_to_keep_until_disk_sync.<font color='#BB00BB'>pop_front</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Always retry if there are any nan or inf values |
|
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font> x : previous_loss_values_to_keep_until_disk_sync<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>std::<font color='#BB00BB'>isnan</font><font face='Lucida Console'>(</font>x<font face='Lucida Console'>)</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> std::<font color='#BB00BB'>isinf</font><font face='Lucida Console'>(</font>x<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font> <font color='#979000'>true</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// if we haven't seen much data yet then just say false. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>gradient_updates_since_last_sync <font color='#5555FF'><</font> <font color='#979000'>30</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font> <font color='#979000'>false</font>; |
|
|
|
<font color='#009900'>// if the loss is very likely to be increasing then return true |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> prob1 <font color='#5555FF'>=</font> <font color='#BB00BB'>probability_values_are_increasing</font><font face='Lucida Console'>(</font>previous_loss_values_to_keep_until_disk_sync<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> prob2 <font color='#5555FF'>=</font> <font color='#BB00BB'>probability_values_are_increasing_robust</font><font face='Lucida Console'>(</font>previous_loss_values_to_keep_until_disk_sync<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>std::<font color='#BB00BB'>max</font><font face='Lucida Console'>(</font>prob1, prob2<font face='Lucida Console'>)</font> <font color='#5555FF'>></font> prob_loss_increasing_thresh<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Exponentially decay the threshold towards 1 so that if we keep finding |
|
</font> <font color='#009900'>// the loss to be increasing over and over we will make the test |
|
</font> <font color='#009900'>// progressively harder and harder until it fails, therefore ensuring we |
|
</font> <font color='#009900'>// can't get stuck reloading from a previous state over and over. |
|
</font> prob_loss_increasing_thresh <font color='#5555FF'>=</font> std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font> |
|
<font color='#979000'>0.1</font><font color='#5555FF'>*</font>prob_loss_increasing_thresh <font color='#5555FF'>+</font> <font color='#979000'>0.9</font><font color='#5555FF'>*</font><font color='#979000'>1</font>, |
|
prob_loss_increasing_thresh_max_value |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font> <font color='#979000'>true</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
<font color='#009900'>// decay back to the default threshold |
|
</font> prob_loss_increasing_thresh <font color='#5555FF'>=</font> std::<font color='#BB00BB'>pow</font><font face='Lucida Console'>(</font>prob_loss_increasing_thresh, <font color='#979000'>10.0</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// but don't decay below the default value |
|
</font> prob_loss_increasing_thresh <font color='#5555FF'>=</font> std::<font color='#BB00BB'>max</font><font face='Lucida Console'>(</font>prob_loss_increasing_thresh, prob_loss_increasing_thresh_default_value<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>return</font> <font color='#979000'>false</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
|
|
// Empty tag type. Passing it as the extra last argument of device_data's |
// constructor selects the overload that makes its own copy of the network. |
<font color='#0000FF'>struct</font> <b><a name='clone_net'></a>clone_net</b><b>{</b><b>}</b>; |
|
|
|
<font color='#009900'>// per device state. All the containers have the same number of objects in them. |
// Holds a device id, a reference to the network used with that device, and |
// num_computational_layers copies of the solver passed to the constructor. |
|
</font> <font color='#0000FF'>struct</font> <b><a name='device_data'></a>device_data</b> |
|
<b>{</b> |
|
// Non-cloning form: net refers directly to net_ and net_copy stays empty. |
<b><a name='device_data'></a>device_data</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>int</u></font> device_id_, |
|
net_type<font color='#5555FF'>&</font> net_, |
|
<font color='#0000FF'>const</font> solver_type<font color='#5555FF'>&</font> solver_ |
|
<font face='Lucida Console'>)</font> : device_id<font face='Lucida Console'>(</font>device_id_<font face='Lucida Console'>)</font>, net<font face='Lucida Console'>(</font>net_<font face='Lucida Console'>)</font>, solvers<font face='Lucida Console'>(</font>num_computational_layers, solver_<font face='Lucida Console'>)</font> <b>{</b><b>}</b> |
|
|
|
// Cloning form (selected by the unnamed clone_net argument): net_copy holds a |
// fresh copy of net_ and net refers to that copy. This relies on net_copy |
// being declared before net below, so that it is initialized first. |
<b><a name='device_data'></a>device_data</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>int</u></font> device_id_, |
|
net_type<font color='#5555FF'>&</font> net_, |
|
<font color='#0000FF'>const</font> solver_type<font color='#5555FF'>&</font> solver_, |
|
clone_net |
|
<font face='Lucida Console'>)</font> : device_id<font face='Lucida Console'>(</font>device_id_<font face='Lucida Console'>)</font>, net_copy<font face='Lucida Console'>(</font>std::make_shared<font color='#5555FF'><</font>net_type<font color='#5555FF'>></font><font face='Lucida Console'>(</font>net_<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>, net<font face='Lucida Console'>(</font><font color='#5555FF'>*</font>net_copy<font face='Lucida Console'>)</font>, solvers<font face='Lucida Console'>(</font>num_computational_layers, solver_<font face='Lucida Console'>)</font> <b>{</b><b>}</b> |
|
|
|
// Id handed to dlib::cuda::set_device() when send_job() prepares this device's data. |
<font color='#0000FF'><u>int</u></font> device_id; |
|
// Only set by the clone_net constructor; owns the copied network. |
std::shared_ptr<font color='#5555FF'><</font>net_type<font color='#5555FF'>></font> net_copy; |
|
// The network used for this device: either the caller's net_ or *net_copy. |
net_type<font color='#5555FF'>&</font> net; |
|
// num_computational_layers copies of the solver given at construction. |
std::vector<font color='#5555FF'><</font>solver_type<font color='#5555FF'>></font> solvers; |
|
<b>}</b>; |
|
|
|
// Builds the next job from [dbegin, dend) and the labels starting at lbegin, |
// then enqueues it on job_pipe. Any exception stored in eptr is rethrown first. |
// The input range is split into devices.size() consecutive chunks. For each |
// device with a non-empty chunk the samples are converted with that device's |
// net.to_tensor() into job.t[i] and the matching labels are copied into |
// job.labels[i]; job.have_data[i] records whether device i received anything. |
// Both iterator types must support "iterator + offset" as used below. |
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> |
|
<font color='#0000FF'>typename</font> data_iterator, |
|
<font color='#0000FF'>typename</font> label_iterator |
|
<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='send_job'></a>send_job</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>bool</u></font> test_only, |
|
data_iterator dbegin, |
|
data_iterator dend, |
|
label_iterator lbegin |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>propagate_exception</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'><u>size_t</u></font> num <font color='#5555FF'>=</font> std::<font color='#BB00BB'>distance</font><font face='Lucida Console'>(</font>dbegin, dend<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'><u>size_t</u></font> devs <font color='#5555FF'>=</font> devices.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
// One slot per device in each of the job's per-device containers. |
job.t.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>devs<font face='Lucida Console'>)</font>; |
|
job.labels.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>devs<font face='Lucida Console'>)</font>; |
|
job.have_data.<font color='#BB00BB'>resize</font><font face='Lucida Console'>(</font>devs<font face='Lucida Console'>)</font>; |
|
job.test_only <font color='#5555FF'>=</font> test_only; |
|
|
|
<font color='#009900'>// chop the data into devs blocks, each of about block_size elements. |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> block_size <font color='#5555FF'>=</font> num <font color='#5555FF'>/</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font>devs<font face='Lucida Console'>)</font>; |
|
|
|
// Remember the active device so it can be restored after the loop. |
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> prev_dev <font color='#5555FF'>=</font> dlib::cuda::<font color='#BB00BB'>get_device</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
// j is the running (fractional) start offset of the current device's chunk. |
<font color='#0000FF'><u>double</u></font> j <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
|
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> devs; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font>device_id<font face='Lucida Console'>)</font>; |
|
|
|
// Integer chunk bounds come from rounding j and j + block_size. The value |
// rounded for this chunk's stop is the one rounded for the next chunk's |
// start, so consecutive chunks neither overlap nor leave gaps. |
<font color='#0000FF'>const</font> <font color='#0000FF'><u>size_t</u></font> start <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'><</font><font color='#0000FF'><u>size_t</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>round</font><font face='Lucida Console'>(</font>j<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>size_t</u></font> stop <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'><</font><font color='#0000FF'><u>size_t</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>round</font><font face='Lucida Console'>(</font>j <font color='#5555FF'>+</font> block_size<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
// An empty chunk (start == stop) happens when there are fewer samples than devices. |
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>start <font color='#5555FF'><</font> stop<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
devices[i]<font color='#5555FF'>-</font><font color='#5555FF'>></font>net.<font color='#BB00BB'>to_tensor</font><font face='Lucida Console'>(</font>dbegin<font color='#5555FF'>+</font>start, dbegin<font color='#5555FF'>+</font>stop, job.t[i]<font face='Lucida Console'>)</font>; |
|
job.labels[i].<font color='#BB00BB'>assign</font><font face='Lucida Console'>(</font>lbegin<font color='#5555FF'>+</font>start, lbegin<font color='#5555FF'>+</font>stop<font face='Lucida Console'>)</font>; |
|
job.have_data[i] <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
job.have_data[i] <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
<b>}</b> |
|
|
|
j <font color='#5555FF'>+</font><font color='#5555FF'>=</font> block_size; |
|
<b>}</b> |
|
|
|
// Sanity check: after devs steps of block_size, j should have reached num. |
<font color='#BB00BB'>DLIB_ASSERT</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>fabs</font><font face='Lucida Console'>(</font>j <font color='#5555FF'>-</font> num<font face='Lucida Console'>)</font> <font color='#5555FF'><</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>10</font><font face='Lucida Console'>)</font>; |
|
|
|
dlib::cuda::<font color='#BB00BB'>set_device</font><font face='Lucida Console'>(</font>prev_dev<font face='Lucida Console'>)</font>; |
|
job_pipe.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>job<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
// Label-free overload: forwards to the four-argument send_job, supplying a |
// default-initialized std::vector of training_label_type iterator as lbegin. |
// NOTE(review): the other overload still advances that iterator and passes it |
// to labels.assign(); presumably harmless for the label types used on this |
// path, but confirm. |
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> |
|
<font color='#0000FF'>typename</font> data_iterator |
|
<font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='send_job'></a>send_job</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'><u>bool</u></font> test_only, |
|
data_iterator dbegin, |
|
data_iterator dend |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>typename</font> std::vector<font color='#5555FF'><</font>training_label_type<font color='#5555FF'>></font>::iterator nothing; |
|
<font color='#BB00BB'>send_job</font><font face='Lucida Console'>(</font>test_only, dbegin, dend, nothing<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
// Writes a one-line progress summary to std::cout, terminated by std::endl. |
// With no explicit learning rate schedule (lr_schedule is empty) it prints the |
// number of steps without apparent progress, for training only or for both |
// training and testing when test loss values have been recorded. With a |
// schedule it prints how far lr_schedule_pos is through it, as a percentage. |
<font color='#0000FF'><u>void</u></font> <b><a name='print_progress'></a>print_progress</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>lr_schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>test_previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>steps without apparent progress: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> steps_without_progress; |
|
<font color='#0000FF'>else</font> |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>steps without apparent progress: train=</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> steps_without_progress <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>, test=</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> test_steps_without_progress; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
// The percentage is formatted in a local stream, so std::fixed and |
// std::setprecision do not touch std::cout's own format flags. |
std::ostringstream sout; |
|
sout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>percent complete: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> std::fixed <font color='#5555FF'><</font><font color='#5555FF'><</font> std::<font color='#BB00BB'>setprecision</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#979000'>100.0</font><font color='#5555FF'>*</font>lr_schedule_pos<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font>lr_schedule.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>%</font>"; |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> sout.<font color='#BB00BB'>str</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> std::endl; |
|
<b>}</b> |
|
|
|
// When verbose is set and more than 40 seconds have passed since last_time, |
// prints a status line to std::cout (step count, learning rate, loss values, |
// then the output of print_progress()) and resets the running average loss. |
// Does nothing otherwise; in particular the average loss is only cleared when |
// a status line is actually printed. |
<font color='#0000FF'><u>void</u></font> <b><a name='print_periodic_verbose_status'></a>print_periodic_verbose_status</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>verbose<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std::chrono; |
|
<font color='#0000FF'>auto</font> now_time <font color='#5555FF'>=</font> system_clock::<font color='#BB00BB'>now</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>now_time<font color='#5555FF'>-</font>last_time <font color='#5555FF'>></font> <font color='#BB00BB'>seconds</font><font face='Lucida Console'>(</font><font color='#979000'>40</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
last_time <font color='#5555FF'>=</font> now_time; |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>step#: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>train_one_step_calls<font face='Lucida Console'>)</font>,epoch_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>" |
|
<font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>learning rate: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>learning_rate<font face='Lucida Console'>)</font>,lr_string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>"; |
|
// No recorded test losses: report a single average loss. Otherwise |
// report the train and test averages separately. |
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>test_previous_loss_values.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>average loss: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>"; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>train loss: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>"; |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>test loss: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>rpad</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_average_test_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,string_pad<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>"; |
|
<b>}</b> |
|
<font color='#BB00BB'>print_progress</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>clear_average_loss</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
// Data members. The notes below only describe uses visible in the nearby |
// member functions; other uses elsewhere in the class are not covered. |
// One device_data entry per device; send_job() prepares data for each in turn. |
std::vector<font color='#5555FF'><</font>std::shared_ptr<font color='#5555FF'><</font>device_data<font color='#5555FF'>></font><font color='#5555FF'>></font> devices; |
|
// Jobs built by send_job() are enqueued here. |
dlib::pipe<font color='#5555FF'><</font>job_t<font color='#5555FF'>></font> job_pipe; |
|
std::shared_ptr<font color='#5555FF'><</font>threads<font color='#5555FF'>></font> thread_pools; |
|
// Job object that send_job() fills in and then enqueues on job_pipe. |
job_t job; |
|
|
|
|
|
running_stats<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font> rs; |
|
running_stats_decayed<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font> rs_test; |
|
std::deque<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font> previous_loss_values; |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> max_num_epochs; |
|
<font color='#0000FF'><u>size_t</u></font> mini_batch_size; |
|
// When false, print_periodic_verbose_status() prints nothing. |
<font color='#0000FF'><u>bool</u></font> verbose; |
|
net_type<font color='#5555FF'>&</font> net; |
|
std::atomic<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font> learning_rate; |
|
<font color='#0000FF'><u>double</u></font> min_learning_rate; |
|
std::atomic<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font> iter_without_progress_thresh; |
|
std::atomic<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font> steps_without_progress; |
|
|
|
std::atomic<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font> test_iter_without_progress_thresh; |
|
std::atomic<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font> test_steps_without_progress; |
|
// While this is empty the status output reports training loss only. |
std::deque<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font> test_previous_loss_values; |
|
|
|
// Recent loss values that are checked for nan/inf and for an increasing |
// trend when deciding whether the loss has gotten worse since the last |
// disk sync. |
std::deque<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font> previous_loss_values_to_keep_until_disk_sync; |
|
|
|
std::atomic<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font> learning_rate_shrink; |
|
std::chrono::time_point<font color='#5555FF'><</font>std::chrono::system_clock<font color='#5555FF'>></font> last_sync_time; |
|
std::string sync_filename; |
|
std::chrono::seconds time_between_syncs; |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> epoch_iteration; |
|
<font color='#0000FF'><u>size_t</u></font> epoch_pos; |
|
// Time of the last status line printed by print_periodic_verbose_status(). |
std::chrono::time_point<font color='#5555FF'><</font>std::chrono::system_clock<font color='#5555FF'>></font> last_time; |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <font color='#0000FF'><u>long</u></font> train_one_step_calls; |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <font color='#0000FF'><u>long</u></font> test_one_step_calls; |
|
// Explicit learning rate schedule. print_progress() treats size() == 0 as |
// "no schedule"; otherwise it reports lr_schedule_pos / size() as a percentage. |
matrix<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>></font> lr_schedule; |
|
<font color='#0000FF'><u>long</u></font> lr_schedule_pos; |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> gradient_check_budget; |
|
|
|
// Stored exception, if any. It is read under eptr_mutex by propagate_exception(). |
std::exception_ptr eptr <font color='#5555FF'>=</font> nullptr; |
|
// mutable so that the const propagate_exception() can lock it. |
<font color='#0000FF'>mutable</font> std::mutex eptr_mutex; |
|
// Rethrows the exception held in eptr, if there is one; otherwise returns |
// normally. eptr is not cleared here, so while it stays set every call throws. |
<font color='#0000FF'><u>void</u></font> <b><a name='propagate_exception'></a>propagate_exception</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
<b>{</b> |
|
std::lock_guard<font color='#5555FF'><</font>std::mutex<font color='#5555FF'>></font> <font color='#BB00BB'>lock</font><font face='Lucida Console'>(</font>eptr_mutex<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>eptr<font face='Lucida Console'>)</font> |
|
std::<font color='#BB00BB'>rethrow_exception</font><font face='Lucida Console'>(</font>eptr<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// These 5 variables are not serialized |
// NOTE(review): six declarations follow before the next blank line; presumably |
// the count refers to the first five. Confirm against the serialization code. |
// The difference of the two counters below is the number of gradient updates |
// since the last disk sync. |
|
</font> <font color='#0000FF'><u>size_t</u></font> main_iteration_counter; |
|
<font color='#0000FF'><u>size_t</u></font> main_iteration_counter_at_last_disk_sync; |
|
// Floor that prob_loss_increasing_thresh decays back to. |
<font color='#0000FF'><u>double</u></font> prob_loss_increasing_thresh_default_value; |
|
// Cap applied when prob_loss_increasing_thresh is pushed towards 1. |
<font color='#0000FF'><u>double</u></font> prob_loss_increasing_thresh_max_value; |
|
// Current threshold that the estimated probability of an increasing loss must |
// exceed for the loss to be judged as increasing. |
<font color='#0000FF'><u>double</u></font> prob_loss_increasing_thresh; |
|
std::atomic<font color='#5555FF'><</font><font color='#0000FF'><u>bool</u></font><font color='#5555FF'>></font> updated_net_since_last_sync; |
|
|
|
<font color='#0000FF'><u>bool</u></font> sync_file_reloaded; |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> previous_loss_values_dump_amount; |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> test_previous_loss_values_dump_amount; |
|
<b>}</b>; |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
// Writes a human-readable, multi-line summary of a dnn_trainer to out and |
// returns out. The summary covers the layer count, serialized network size, |
// an md5 hash of the network's textual description (without the loss layer), |
// the loss details, step count, synchronization file, first solver and mini |
// batch size. It ends with either a note that an explicit learning rate |
// schedule is in use or the individual learning rate settings. |
// The trainer is taken by non-const reference. |
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> |
|
<font color='#0000FF'>typename</font> net_type, |
|
<font color='#0000FF'>typename</font> solver_type |
|
<font color='#5555FF'>></font> |
|
std::ostream<font color='#5555FF'>&</font> <b><a name='operator'></a>operator</b><font color='#5555FF'><</font><font color='#5555FF'><</font> <font face='Lucida Console'>(</font> |
|
std::ostream<font color='#5555FF'>&</font> out, |
|
dnn_trainer<font color='#5555FF'><</font>net_type,solver_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> trainer |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>using</font> std::endl; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>dnn_trainer details: \n</font>"; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> net_type::num_layers: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> net_type::num_layers <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#009900'>// figure out how big the net is in MB. |
// The size is the length of the serialized, cleaned copy divided by 1024 twice. |
|
</font> std::ostringstream sout; |
|
net_type temp <font color='#5555FF'>=</font> trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#009900'>// make a copy so that we can clean it without mutating the trainer's net. |
|
</font> temp.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>temp, sout<font face='Lucida Console'>)</font>; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> net size: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> sout.<font color='#BB00BB'>str</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font><font color='#979000'>1024.0</font><font color='#5555FF'>/</font><font color='#979000'>1024.0</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> MiB</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#009900'>// Don't include the loss params in the hash since we print them on the next line. |
|
</font> <font color='#009900'>// They also aren't really part of the "architecture" of the network. |
|
</font> out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> net architecture hash: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>md5</font><font face='Lucida Console'>(</font><font color='#BB00BB'>cast_to_string</font><font face='Lucida Console'>(</font>trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> loss: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>loss_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> get_train_one_step_calls(): </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> trainer.<font color='#BB00BB'>get_train_one_step_calls</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> synchronization file: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> trainer.<font color='#BB00BB'>get_synchronization_file</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
// Only the first solver is printed. |
// NOTE(review): this indexes element 0 unconditionally, so it assumes |
// get_solvers() is never empty. Confirm. |
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> trainer.get_solvers()[0]: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> trainer.<font color='#BB00BB'>get_solvers</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>[<font color='#979000'>0</font>] <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> mini batch size: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> trainer.<font color='#BB00BB'>get_mini_batch_size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
// The individual learning rate settings are only printed when no explicit |
// schedule is in use. |
<font color='#0000FF'>auto</font> sched <font color='#5555FF'>=</font> trainer.<font color='#BB00BB'>get_learning_rate_schedule</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>sched.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> using explicit user-supplied learning rate schedule</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> learning rate: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> learning rate shrink factor: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> trainer.<font color='#BB00BB'>get_learning_rate_shrink_factor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> min learning rate: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> trainer.<font color='#BB00BB'>get_min_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> iterations without progress threshold: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> trainer.<font color='#BB00BB'>get_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
out <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> test iterations without progress threshold: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> trainer.<font color='#BB00BB'>get_test_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<b>}</b> |
|
<font color='#0000FF'>return</font> out; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_DNn_TRAINER_H_ |
|
</font> |
|
|
|
</pre></body></html> |