|
<!doctype html><html lang="en"><head><meta charset="utf-8"><title>dlib C++ Library - dnn_imagenet_train_ex.cpp</title></head><body bgcolor='white'><pre> |
|
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt |
|
</font><font color='#009900'>/* |
|
This program was used to train the resnet34_1000_imagenet_classifier.dnn |
|
network used by the <a href="dnn_imagenet_ex.cpp.html">dnn_imagenet_ex.cpp</a> example program. |
|
|
|
You should be familiar with dlib's DNN module before reading this example |
|
program. So read <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a> and <a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a> first. |
|
*/</font> |
|
|
|
|
|
|
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>dir_nav.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>dnn.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>image_transforms.h<font color='#5555FF'>></font> |
|
|
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>algorithm<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>fstream<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iterator<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>string<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>thread<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>vector<font color='#5555FF'>></font> |
|
|
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std; |
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib; |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#009900'>// residual: a residual unit with an identity shortcut. The input is tagged with tag1,</font> |
|
<font color='#009900'>// run through block with stride 1, and add_prev1 adds the tag1 output to the result.</font> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font><font color='#0000FF'>class</font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b>block</b>, <font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font><font color='#0000FF'>class</font> <b>BN</b>, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> |
|
<font color='#0000FF'>using</font> residual <font color='#5555FF'>=</font> add_prev1<font color='#5555FF'><</font>block<font color='#5555FF'><</font>N,BN,<font color='#979000'>1</font>,tag1<font color='#5555FF'><</font>SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
|
|
<font color='#009900'>// residual_down: a downsampling residual unit. block runs with stride 2 on the</font> |
|
<font color='#009900'>// tag1-tagged input and its output is tagged tag2. skip1 then returns to the tag1</font> |
|
<font color='#009900'>// output (the unit's input), which avg_pool shrinks with a 2x2 window and stride 2 so</font> |
|
<font color='#009900'>// the shortcut matches the downsampled block output before add_prev2 adds the two.</font> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font><font color='#0000FF'>class</font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b>block</b>, <font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font><font color='#0000FF'>class</font> <b>BN</b>, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> |
|
<font color='#0000FF'>using</font> residual_down <font color='#5555FF'>=</font> add_prev2<font color='#5555FF'><</font>avg_pool<font color='#5555FF'><</font><font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,skip1<font color='#5555FF'><</font>tag2<font color='#5555FF'><</font>block<font color='#5555FF'><</font>N,BN,<font color='#979000'>2</font>,tag1<font color='#5555FF'><</font>SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
|
|
<font color='#009900'>// block: the two-convolution body of a basic ResNet unit. The input goes through a</font> |
|
<font color='#009900'>// 3x3 con with N filters and the given stride, then BN and relu, then a second 3x3</font> |
|
<font color='#009900'>// con with N filters and stride 1, then BN. There is no trailing relu here; the</font> |
|
<font color='#009900'>// res/ares wrappers apply it after the residual sum.</font> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b>BN</b>, <font color='#0000FF'><u>int</u></font> stride, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> |
|
<font color='#0000FF'>using</font> block <font color='#5555FF'>=</font> BN<font color='#5555FF'><</font>con<font color='#5555FF'><</font>N,<font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,relu<font color='#5555FF'><</font>BN<font color='#5555FF'><</font>con<font color='#5555FF'><</font>N,<font color='#979000'>3</font>,<font color='#979000'>3</font>,stride,stride,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
|
|
|
|
<font color='#009900'>// Complete residual units: each wraps residual or residual_down in a final relu.</font> |
|
<font color='#009900'>// res and res_down normalize with bn_con (batch normalization) and build the training</font> |
|
<font color='#009900'>// network; ares and ares_down use affine instead and build the testing network.</font> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> res <font color='#5555FF'>=</font> relu<font color='#5555FF'><</font>residual<font color='#5555FF'><</font>block,N,bn_con,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> ares <font color='#5555FF'>=</font> relu<font color='#5555FF'><</font>residual<font color='#5555FF'><</font>block,N,affine,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> res_down <font color='#5555FF'>=</font> relu<font color='#5555FF'><</font>residual_down<font color='#5555FF'><</font>block,N,bn_con,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> ares_down <font color='#5555FF'>=</font> relu<font color='#5555FF'><</font>residual_down<font color='#5555FF'><</font>block,N,affine,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
|
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#009900'>// The four stages of the ResNet-34 body. level4 is applied first (closest to the input)</font> |
|
<font color='#009900'>// and level1 last. Counting the res_down units, the stages contain 3, 4, 6 and 3</font> |
|
<font color='#009900'>// residual units with 64, 128, 256 and 512 filters respectively. Every stage except</font> |
|
<font color='#009900'>// level4 begins with a res_down, whose stride-2 convolution downsamples the input.</font> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> level1 <font color='#5555FF'>=</font> res<font color='#5555FF'><</font><font color='#979000'>512</font>,res<font color='#5555FF'><</font><font color='#979000'>512</font>,res_down<font color='#5555FF'><</font><font color='#979000'>512</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> level2 <font color='#5555FF'>=</font> res<font color='#5555FF'><</font><font color='#979000'>256</font>,res<font color='#5555FF'><</font><font color='#979000'>256</font>,res<font color='#5555FF'><</font><font color='#979000'>256</font>,res<font color='#5555FF'><</font><font color='#979000'>256</font>,res<font color='#5555FF'><</font><font color='#979000'>256</font>,res_down<font color='#5555FF'><</font><font color='#979000'>256</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> level3 <font color='#5555FF'>=</font> res<font color='#5555FF'><</font><font color='#979000'>128</font>,res<font color='#5555FF'><</font><font color='#979000'>128</font>,res<font color='#5555FF'><</font><font color='#979000'>128</font>,res_down<font color='#5555FF'><</font><font color='#979000'>128</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> level4 <font color='#5555FF'>=</font> res<font color='#5555FF'><</font><font color='#979000'>64</font>,res<font color='#5555FF'><</font><font color='#979000'>64</font>,res<font color='#5555FF'><</font><font color='#979000'>64</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
|
|
<font color='#009900'>// The same four stages built from the affine-based ares/ares_down units, for the</font> |
|
<font color='#009900'>// testing network.</font> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> alevel1 <font color='#5555FF'>=</font> ares<font color='#5555FF'><</font><font color='#979000'>512</font>,ares<font color='#5555FF'><</font><font color='#979000'>512</font>,ares_down<font color='#5555FF'><</font><font color='#979000'>512</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> alevel2 <font color='#5555FF'>=</font> ares<font color='#5555FF'><</font><font color='#979000'>256</font>,ares<font color='#5555FF'><</font><font color='#979000'>256</font>,ares<font color='#5555FF'><</font><font color='#979000'>256</font>,ares<font color='#5555FF'><</font><font color='#979000'>256</font>,ares<font color='#5555FF'><</font><font color='#979000'>256</font>,ares_down<font color='#5555FF'><</font><font color='#979000'>256</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> alevel3 <font color='#5555FF'>=</font> ares<font color='#5555FF'><</font><font color='#979000'>128</font>,ares<font color='#5555FF'><</font><font color='#979000'>128</font>,ares<font color='#5555FF'><</font><font color='#979000'>128</font>,ares_down<font color='#5555FF'><</font><font color='#979000'>128</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> alevel4 <font color='#5555FF'>=</font> ares<font color='#5555FF'><</font><font color='#979000'>64</font>,ares<font color='#5555FF'><</font><font color='#979000'>64</font>,ares<font color='#5555FF'><</font><font color='#979000'>64</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
|
|
<font color='#009900'>// Training network type: a ResNet-34 for 227x227 RGB input images. The stem is a |
|
// 7x7, stride 2 con with 64 filters, followed by bn_con, relu and a 3x3, stride 2 |
|
// max_pool. Then come level4 through level1, global average pooling |
|
// (avg_pool_everything) and a 1000-output fc layer trained with loss_multiclass_log. |
|
</font><font color='#0000FF'>using</font> net_type <font color='#5555FF'>=</font> loss_multiclass_log<font color='#5555FF'><</font>fc<font color='#5555FF'><</font><font color='#979000'>1000</font>,avg_pool_everything<font color='#5555FF'><</font> |
|
level1<font color='#5555FF'><</font> |
|
level2<font color='#5555FF'><</font> |
|
level3<font color='#5555FF'><</font> |
|
level4<font color='#5555FF'><</font> |
|
max_pool<font color='#5555FF'><</font><font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,relu<font color='#5555FF'><</font>bn_con<font color='#5555FF'><</font>con<font color='#5555FF'><</font><font color='#979000'>64</font>,<font color='#979000'>7</font>,<font color='#979000'>7</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>, |
|
input_rgb_image_sized<font color='#5555FF'><</font><font color='#979000'>227</font><font color='#5555FF'>></font> |
|
<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
|
|
<font color='#009900'>// Testing network type: identical to net_type layer for layer, except that every |
|
// bn_con (batch normalization) is replaced by a fixed affine transform, via the |
|
// affine-based alevel stages and the affine stem. NOTE(review): presumably main |
|
// converts the trained net_type into this type for evaluation; confirm there. |
|
</font><font color='#0000FF'>using</font> anet_type <font color='#5555FF'>=</font> loss_multiclass_log<font color='#5555FF'><</font>fc<font color='#5555FF'><</font><font color='#979000'>1000</font>,avg_pool_everything<font color='#5555FF'><</font> |
|
alevel1<font color='#5555FF'><</font> |
|
alevel2<font color='#5555FF'><</font> |
|
alevel3<font color='#5555FF'><</font> |
|
alevel4<font color='#5555FF'><</font> |
|
max_pool<font color='#5555FF'><</font><font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,relu<font color='#5555FF'><</font>affine<font color='#5555FF'><</font>con<font color='#5555FF'><</font><font color='#979000'>64</font>,<font color='#979000'>7</font>,<font color='#979000'>7</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>, |
|
input_rgb_image_sized<font color='#5555FF'><</font><font color='#979000'>227</font><font color='#5555FF'>></font> |
|
<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#009900'>// Returns a random square rectangle lying inside img, used as a training crop. Its</font> |
|
<font color='#009900'>// side is a random fraction, between mins and maxs, of the shorter image side,</font> |
|
<font color='#009900'>// truncated to an integer. Draws from rnd in a fixed order: one get_random_double()</font> |
|
<font color='#009900'>// for the size, then one get_random_32bit_number() for x and one for y.</font> |
|
<font color='#009900'>// Requires img to have at least one row and one column; otherwise one of the</font> |
|
<font color='#009900'>// modulo divisors below is zero.</font> |
|
rectangle <b><a name='make_random_cropping_rect_resnet'></a>make_random_cropping_rect_resnet</b><font face='Lucida Console'>(</font> |
|
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> img, |
|
    dlib::rand<font color='#5555FF'>&</font> rnd |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
    <font color='#009900'>// figure out what rectangle we want to crop from the image</font> |
|
    <font color='#0000FF'><u>double</u></font> mins <font color='#5555FF'>=</font> <font color='#979000'>0.466666666</font>, maxs <font color='#5555FF'>=</font> <font color='#979000'>0.875</font>; |
|
    <font color='#0000FF'>auto</font> scale <font color='#5555FF'>=</font> mins <font color='#5555FF'>+</font> rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font><font face='Lucida Console'>(</font>maxs<font color='#5555FF'>-</font>mins<font face='Lucida Console'>)</font>; |
|
    <font color='#0000FF'>auto</font> size <font color='#5555FF'>=</font> scale<font color='#5555FF'>*</font>std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
    rectangle <font color='#BB00BB'>rect</font><font face='Lucida Console'>(</font>size, size<font face='Lucida Console'>)</font>; |
|
    <font color='#009900'>// Randomly shift the box while keeping it inside the image. The x and y offsets</font> |
|
    <font color='#009900'>// are drawn in separate statements so x always gets the first random number:</font> |
|
    <font color='#009900'>// within a single constructor call the argument evaluation order is unspecified,</font> |
|
    <font color='#009900'>// so the same seed could yield different crops under different compilers.</font> |
|
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> x <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> y <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
    <font color='#0000FF'>return</font> <font color='#BB00BB'>move_rect</font><font face='Lucida Console'>(</font>rect, <font color='#BB00BB'>point</font><font face='Lucida Console'>(</font>x, y<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#009900'>// Produces one augmented training sample in crop: a random 227x227 crop of img, chosen</font> |
|
<font color='#009900'>// by make_random_cropping_rect_resnet, mirrored left to right when a random draw</font> |
|
<font color='#009900'>// exceeds 0.5, and finally given a random color offset.</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='randomly_crop_image'></a>randomly_crop_image</b> <font face='Lucida Console'>(</font> |
|
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> img, |
|
    matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> crop, |
|
    dlib::rand<font color='#5555FF'>&</font> rnd |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
    <font color='#009900'>// Cut a random square region out of img and resample it to 227x227 pixels.</font> |
|
    <font color='#0000FF'>const</font> chip_details <font color='#BB00BB'>region</font><font face='Lucida Console'>(</font><font color='#BB00BB'>make_random_cropping_rect_resnet</font><font face='Lucida Console'>(</font>img, rnd<font face='Lucida Console'>)</font>, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>227</font>,<font color='#979000'>227</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
    <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>img, region, crop<font face='Lucida Console'>)</font>; |
|
|
|
    <font color='#009900'>// Decide whether to mirror the crop horizontally.</font> |
|
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> mirror <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0.5</font>; |
|
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>mirror<font face='Lucida Console'>)</font> |
|
        crop <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>crop<font face='Lucida Console'>)</font>; |
|
|
|
    <font color='#009900'>// Finally, jitter the colors.</font> |
|
    <font color='#BB00BB'>apply_random_color_offset</font><font face='Lucida Console'>(</font>crop, rnd<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Fills crops with num_crops augmented 227x227 crops of img. Each crop gets its own</font> |
|
<font color='#009900'>// random cropping rectangle, a left-right mirror when a random draw exceeds 0.5, and</font> |
|
<font color='#009900'>// a random color offset.</font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='randomly_crop_images'></a>randomly_crop_images</b> <font face='Lucida Console'>(</font> |
|
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> img, |
|
    dlib::array<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>&</font> crops, |
|
    dlib::rand<font color='#5555FF'>&</font> rnd, |
|
    <font color='#0000FF'><u>long</u></font> num_crops |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
    <font color='#009900'>// Choose all the cropping rectangles first, then extract them in a single call.</font> |
|
    std::vector<font color='#5555FF'><</font>chip_details<font color='#5555FF'>></font> dets; |
|
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> num_crops; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
    <b>{</b> |
|
        <font color='#0000FF'>auto</font> rect <font color='#5555FF'>=</font> <font color='#BB00BB'>make_random_cropping_rect_resnet</font><font face='Lucida Console'>(</font>img, rnd<font face='Lucida Console'>)</font>; |
|
        dets.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#BB00BB'>chip_details</font><font face='Lucida Console'>(</font>rect, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>227</font>,<font color='#979000'>227</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
    <b>}</b> |
|
|
|
    <font color='#BB00BB'>extract_image_chips</font><font face='Lucida Console'>(</font>img, dets, crops<font face='Lucida Console'>)</font>; |
|
|
|
    <font color='#009900'>// Augment each crop in place. The loop variable is named crop, not img, so it does</font> |
|
    <font color='#009900'>// not shadow the source image parameter.</font> |
|
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&</font><font color='#5555FF'>&</font> crop : crops<font face='Lucida Console'>)</font> |
|
    <b>{</b> |
|
        <font color='#009900'>// Randomly mirror the crop left to right.</font> |
|
        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0.5</font><font face='Lucida Console'>)</font> |
|
            crop <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>crop<font face='Lucida Console'>)</font>; |
|
|
|
        <font color='#009900'>// And then randomly adjust the colors.</font> |
|
        <font color='#BB00BB'>apply_random_color_offset</font><font face='Lucida Console'>(</font>crop, rnd<font face='Lucida Console'>)</font>; |
|
    <b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#009900'>// One entry of a dataset listing, as produced by the get_imagenet_*_listing functions.</font> |
|
<font color='#0000FF'>struct</font> <b><a name='image_info'></a>image_info</b> |
|
<b>{</b> |
|
<font color='#009900'>// Path of the image file.</font> |
|
string filename; |
|
<font color='#009900'>// Class name: the training subfolder name, or the label read from the validation file.</font> |
|
string label; |
|
<font color='#009900'>// Integer class index assigned by the listing functions. It has no default value, so</font> |
|
<font color='#009900'>// it is indeterminate until explicitly set.</font> |
|
<font color='#0000FF'><u>long</u></font> numeric_label; |
|
<b>}</b>; |
|
|
|
<font color='#009900'>// Lists the files in each immediate subfolder of images_folder, one class per</font> |
|
<font color='#009900'>// subfolder. Subfolders are visited in sorted order: a subfolder's name becomes the</font> |
|
<font color='#009900'>// label of its files and its position in that order (starting at 0) becomes their</font> |
|
<font color='#009900'>// numeric label.</font> |
|
std::vector<font color='#5555FF'><</font>image_info<font color='#5555FF'>></font> <b><a name='get_imagenet_train_listing'></a>get_imagenet_train_listing</b><font face='Lucida Console'>(</font> |
|
    <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&</font> images_folder |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
    std::vector<font color='#5555FF'><</font>image_info<font color='#5555FF'>></font> results; |
|
    image_info temp; |
|
    temp.numeric_label <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
    <font color='#009900'>// We will loop over all the label types in the dataset, each is contained in a subfolder.</font> |
|
    <font color='#0000FF'>auto</font> subdirs <font color='#5555FF'>=</font> <font color='#BB00BB'>directory</font><font face='Lucida Console'>(</font>images_folder<font face='Lucida Console'>)</font>.<font color='#BB00BB'>get_dirs</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
    <font color='#009900'>// But first, sort the sub directories so the numeric labels will be assigned in sorted order.</font> |
|
    std::<font color='#BB00BB'>sort</font><font face='Lucida Console'>(</font>subdirs.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, subdirs.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
    <font color='#009900'>// Iterate by const reference so no directory or file object is copied.</font> |
|
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> subdir : subdirs<font face='Lucida Console'>)</font> |
|
    <b>{</b> |
|
        <font color='#009900'>// Now get all the images in this label type</font> |
|
        temp.label <font color='#5555FF'>=</font> subdir.<font color='#BB00BB'>name</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> image_file : subdir.<font color='#BB00BB'>get_files</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            temp.filename <font color='#5555FF'>=</font> image_file; |
|
            results.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>; |
|
        <b>}</b> |
|
        <font color='#009900'>// Every subfolder takes a label index, even one that contains no files.</font> |
|
        <font color='#5555FF'>+</font><font color='#5555FF'>+</font>temp.numeric_label; |
|
    <b>}</b> |
|
    <font color='#0000FF'>return</font> results; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Reads whitespace-separated label/filename pairs from validation_images_file and</font> |
|
<font color='#009900'>// resolves each filename relative to imagenet_root_dir. The numeric label starts at 0</font> |
|
<font color='#009900'>// and is incremented whenever a line's label differs from the previous line's, so</font> |
|
<font color='#009900'>// lines sharing a label must be contiguous. NOTE(review): the numeric labels only</font> |
|
<font color='#009900'>// match get_imagenet_train_listing if the file is also sorted by label; confirm.</font> |
|
<font color='#009900'>// Exits the program if the listing file cannot be opened or a listed image is missing.</font> |
|
std::vector<font color='#5555FF'><</font>image_info<font color='#5555FF'>></font> <b><a name='get_imagenet_val_listing'></a>get_imagenet_val_listing</b><font face='Lucida Console'>(</font> |
|
    <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&</font> imagenet_root_dir, |
|
    <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&</font> validation_images_file |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
    ifstream <font color='#BB00BB'>fin</font><font face='Lucida Console'>(</font>validation_images_file<font face='Lucida Console'>)</font>; |
|
    <font color='#009900'>// Fail loudly rather than silently returning an empty listing.</font> |
|
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>fin<font face='Lucida Console'>)</font> |
|
    <b>{</b> |
|
        cerr <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>unable to open validation images file: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> validation_images_file <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
        <font color='#BB00BB'>exit</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; |
|
    <b>}</b> |
|
    string label, filename; |
|
    std::vector<font color='#5555FF'><</font>image_info<font color='#5555FF'>></font> results; |
|
    image_info temp; |
|
    <font color='#009900'>// The first line's label always differs from the empty initial label, so it gets 0.</font> |
|
    temp.numeric_label <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>1</font>; |
|
    <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>fin <font color='#5555FF'>></font><font color='#5555FF'>></font> label <font color='#5555FF'>></font><font color='#5555FF'>></font> filename<font face='Lucida Console'>)</font> |
|
    <b>{</b> |
|
        temp.filename <font color='#5555FF'>=</font> imagenet_root_dir<font color='#5555FF'>+</font>"<font color='#CC0000'>/</font>"<font color='#5555FF'>+</font>filename; |
|
        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font><font color='#BB00BB'>file_exists</font><font face='Lucida Console'>(</font>temp.filename<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
        <b>{</b> |
|
            cerr <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>file doesn't exist! </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> temp.filename <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
            <font color='#BB00BB'>exit</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; |
|
        <b>}</b> |
|
        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>label <font color='#5555FF'>!</font><font color='#5555FF'>=</font> temp.label<font face='Lucida Console'>)</font> |
|
            <font color='#5555FF'>+</font><font color='#5555FF'>+</font>temp.numeric_label; |
|
|
|
        temp.label <font color='#5555FF'>=</font> label; |
|
        results.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>; |
|
    <b>}</b> |
|
|
|
    <font color='#0000FF'>return</font> results; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font><font color='#5555FF'>*</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>3</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>To run this program you need a copy of the imagenet ILSVRC2015 dataset and</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>also the file http://dlib.net/files/imagenet2015_validation_images.txt.bz2</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>With those things, you call this program like this: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>./dnn_imagenet_train_ex /path/to/ILSVRC2015 imagenet2015_validation_images.txt</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#0000FF'>return</font> <font color='#979000'>1</font>; |
|
<b>}</b> |
|
|
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\nSCANNING IMAGENET DATASET\n</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
<font color='#0000FF'>auto</font> listing <font color='#5555FF'>=</font> <font color='#BB00BB'>get_imagenet_train_listing</font><font face='Lucida Console'>(</font><font color='#BB00BB'>string</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font><font color='#5555FF'>+</font>"<font color='#CC0000'>/Data/CLS-LOC/train/</font>"<font face='Lucida Console'>)</font>; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>images in dataset: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> number_of_classes <font color='#5555FF'>=</font> listing.<font color='#BB00BB'>back</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.numeric_label<font color='#5555FF'>+</font><font color='#979000'>1</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> number_of_classes <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>1000</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Didn't find the imagenet dataset. </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#0000FF'>return</font> <font color='#979000'>1</font>; |
|
<b>}</b> |
|
|
|
<font color='#BB00BB'>set_dnn_prefer_smallest_algorithms</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> initial_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> weight_decay <font color='#5555FF'>=</font> <font color='#979000'>0.0001</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>; |
|
|
|
net_type net; |
|
dnn_trainer<font color='#5555FF'><</font>net_type<font color='#5555FF'>></font> <font color='#BB00BB'>trainer</font><font face='Lucida Console'>(</font>net,<font color='#BB00BB'>sgd</font><font face='Lucida Console'>(</font>weight_decay, momentum<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font>initial_learning_rate<font face='Lucida Console'>)</font>; |
|
trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>imagenet_trainer_state_file.dat</font>", std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// This threshold is probably excessively large. You could likely get good results |
|
</font> <font color='#009900'>// with a smaller value but if you aren't in a hurry this value will surely work well. |
|
</font> trainer.<font color='#BB00BB'>set_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font color='#979000'>20000</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Since the progress threshold is so large, we might as well set the batch normalization |
|
</font> <font color='#009900'>// stats window to something big too. |
|
</font> <font color='#BB00BB'>set_all_bn_running_stats_window_sizes</font><font face='Lucida Console'>(</font>net, <font color='#979000'>1000</font><font face='Lucida Console'>)</font>; |
|
|
|
std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> samples; |
|
std::vector<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font> labels; |
|
|
|
<font color='#009900'>// Start a bunch of threads that read images from disk and pull out random crops. It's |
|
</font> <font color='#009900'>// important to be sure to feed the GPU fast enough to keep it busy. Using multiple |
|
</font> <font color='#009900'>// threads for this kind of data preparation helps us do that. Each thread puts the |
|
</font> <font color='#009900'>// crops into the data queue. |
|
</font> dlib::pipe<font color='#5555FF'><</font>std::pair<font color='#5555FF'><</font>image_info,matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font> <font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>auto</font> f <font color='#5555FF'>=</font> [<font color='#5555FF'>&</font>data, <font color='#5555FF'>&</font>listing]<font face='Lucida Console'>(</font>time_t seed<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>seed<font face='Lucida Console'>)</font>; |
|
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> img; |
|
std::pair<font color='#5555FF'><</font>image_info, matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> temp; |
|
<font color='#0000FF'>while</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>is_enabled</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
temp.first <font color='#5555FF'>=</font> listing[rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>]; |
|
<font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>img, temp.first.filename<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>randomly_crop_image</font><font face='Lucida Console'>(</font>img, temp.second, rnd<font face='Lucida Console'>)</font>; |
|
data.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b>; |
|
std::thread <font color='#BB00BB'>data_loader1</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
std::thread <font color='#BB00BB'>data_loader2</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
std::thread <font color='#BB00BB'>data_loader3</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
std::thread <font color='#BB00BB'>data_loader4</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// The main training loop. Keep making mini-batches and giving them to the trainer. |
|
</font> <font color='#009900'>// We will run until the learning rate has dropped by a factor of 1e-3. |
|
</font> <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font><font color='#5555FF'>=</font> initial_learning_rate<font color='#5555FF'>*</font><font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>3</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
samples.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
labels.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// make a 160 image mini-batch |
|
</font> std::pair<font color='#5555FF'><</font>image_info, matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> img; |
|
<font color='#0000FF'>while</font><font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> <font color='#979000'>160</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
data.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>img<font face='Lucida Console'>)</font>; |
|
|
|
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>img.second<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>img.first.numeric_label<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Training done, tell threads to stop and make sure to wait for them to finish before |
|
</font> <font color='#009900'>// moving on. |
|
</font> data.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader1.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader2.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader3.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader4.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// also wait for threaded processing to stop in the trainer. |
|
</font> trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>saving network</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>resnet34.dnn</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> net; |
|
|
|
|
|
|
|
|
|
|
|
|
|
<font color='#009900'>// Now test the network on the imagenet validation dataset. First, make a testing |
|
</font> <font color='#009900'>// network with softmax as the final layer. We wouldn't have to do this if we just wanted |
|
</font> <font color='#009900'>// to test the "top1 accuracy" since the normal network outputs the class prediction. |
|
</font> <font color='#009900'>// But this snet object will make getting the top5 predictions easy as it directly |
|
</font> <font color='#009900'>// outputs the probability of each class as its final output. |
|
</font> softmax<font color='#5555FF'><</font>anet_type::subnet_type<font color='#5555FF'>></font> snet; snet.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Testing network on imagenet validation dataset...</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#0000FF'><u>int</u></font> num_right <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'><u>int</u></font> num_wrong <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'><u>int</u></font> num_right_top1 <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'><u>int</u></font> num_wrong_top1 <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// loop over all the imagenet validation images |
|
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font> l : <font color='#BB00BB'>get_imagenet_val_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>], argv[<font color='#979000'>2</font>]<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
dlib::array<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> images; |
|
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> img; |
|
<font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>img, l.filename<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Grab 16 random crops from the image. We will run all of them through the |
|
</font> <font color='#009900'>// network and average the results. |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> num_crops <font color='#5555FF'>=</font> <font color='#979000'>16</font>; |
|
<font color='#BB00BB'>randomly_crop_images</font><font face='Lucida Console'>(</font>img, images, rnd, num_crops<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// p(i) == the probability the image contains an object of class i. |
|
</font> matrix<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font>,<font color='#979000'>1</font>,<font color='#979000'>1000</font><font color='#5555FF'>></font> p <font color='#5555FF'>=</font> <font color='#BB00BB'>sum_rows</font><font face='Lucida Console'>(</font><font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font><font color='#BB00BB'>snet</font><font face='Lucida Console'>(</font>images.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, images.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>num_crops; |
|
|
|
<font color='#009900'>// check top 1 accuracy |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>index_of_max</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> l.numeric_label<font face='Lucida Console'>)</font> |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right_top1; |
|
<font color='#0000FF'>else</font> |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong_top1; |
|
|
|
<font color='#009900'>// check top 5 accuracy |
|
</font> <font color='#0000FF'><u>bool</u></font> found_match <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> k <font color='#5555FF'>=</font> <font color='#979000'>0</font>; k <font color='#5555FF'><</font> <font color='#979000'>5</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>k<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'><u>long</u></font> predicted_label <font color='#5555FF'>=</font> <font color='#BB00BB'>index_of_max</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>p</font><font face='Lucida Console'>(</font>predicted_label<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>predicted_label <font color='#5555FF'>=</font><font color='#5555FF'>=</font> l.numeric_label<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
found_match <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |
|
<font color='#0000FF'>break</font>; |
|
<b>}</b> |
|
|
|
<b>}</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>found_match<font face='Lucida Console'>)</font> |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right; |
|
<font color='#0000FF'>else</font> |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong; |
|
<b>}</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>val top5 accuracy: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> num_right<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>num_right<font color='#5555FF'>+</font>num_wrong<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>val top1 accuracy: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> num_right_top1<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>num_right_top1<font color='#5555FF'>+</font>num_wrong_top1<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<b>}</b> |
|
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&</font> e<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<b>}</b> |
|
|
|
|
|
</pre></body></html> |