File size: 52,313 Bytes
9375c9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 |
<html><!-- Created using the cpp_pretty_printer from the dlib C++ library. See http://dlib.net for updates. --><head><title>dlib C++ Library - dnn_imagenet_train_ex.cpp</title></head><body bgcolor='white'><pre>
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
</font><font color='#009900'>/*
This program was used to train the resnet34_1000_imagenet_classifier.dnn
network used by the <a href="dnn_imagenet_ex.cpp.html">dnn_imagenet_ex.cpp</a> example program.
You should be familiar with dlib's DNN module before reading this example
program. So read <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a> and <a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a> first.
*/</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>dnn.h<font color='#5555FF'>></font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>></font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>image_transforms.h<font color='#5555FF'>></font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>dir_nav.h<font color='#5555FF'>></font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iterator<font color='#5555FF'>></font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>thread<font color='#5555FF'>></font>
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std;
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib;
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// residual: applies block (N filters, normalization layer BN, stride 1) to SUBNET,
</font><font color='#009900'>// marking SUBNET with tag1, and wraps the result in add_prev1 so the tagged input is
</font><font color='#009900'>// added back to the block output (the skip connection of a residual unit).
</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font><font color='#0000FF'>class</font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b><a name='block'></a>block</b>, <font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font><font color='#0000FF'>class</font> <b><a name='BN'></a>BN</b>, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font>
<font color='#0000FF'>using</font> residual <font color='#5555FF'>=</font> add_prev1<font color='#5555FF'><</font>block<font color='#5555FF'><</font>N,BN,<font color='#979000'>1</font>,tag1<font color='#5555FF'><</font>SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#009900'>// residual_down: like residual, but the block is instantiated with stride 2 and its
</font><font color='#009900'>// output is marked with tag2.  skip1 then goes back to the tag1 layer (the unit input),
</font><font color='#009900'>// avg_pool with a 2x2 window and stride 2 shrinks it, and add_prev2 adds the tag2
</font><font color='#009900'>// (block) output to the pooled input.
</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font><font color='#0000FF'>class</font>,<font color='#0000FF'><u>int</u></font>,<font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b><a name='block'></a>block</b>, <font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font><font color='#0000FF'>class</font> <b><a name='BN'></a>BN</b>, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font>
<font color='#0000FF'>using</font> residual_down <font color='#5555FF'>=</font> add_prev2<font color='#5555FF'><</font>avg_pool<font color='#5555FF'><</font><font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,skip1<font color='#5555FF'><</font>tag2<font color='#5555FF'><</font>block<font color='#5555FF'><</font>N,BN,<font color='#979000'>2</font>,tag1<font color='#5555FF'><</font>SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#009900'>// block: the body of one residual unit.  Reading from SUBNET outwards it is a 3x3
</font><font color='#009900'>// convolution with N filters and the given stride, BN, relu, a second 3x3 convolution
</font><font color='#009900'>// with N filters and stride 1, and BN again.  BN is the normalization layer to use.
</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font><font color='#5555FF'>></font> <font color='#0000FF'>class</font> <b><a name='BN'></a>BN</b>, <font color='#0000FF'><u>int</u></font> stride, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font>
<font color='#0000FF'>using</font> block <font color='#5555FF'>=</font> BN<font color='#5555FF'><</font>con<font color='#5555FF'><</font>N,<font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,relu<font color='#5555FF'><</font>BN<font color='#5555FF'><</font>con<font color='#5555FF'><</font>N,<font color='#979000'>3</font>,<font color='#979000'>3</font>,stride,stride,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#009900'>// Complete residual units: a relu applied on top of residual or residual_down.
</font><font color='#009900'>// res and res_down plug in bn_con as the normalization layer, while ares and ares_down
</font><font color='#009900'>// are the same units built with affine in its place.
</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> res <font color='#5555FF'>=</font> relu<font color='#5555FF'><</font>residual<font color='#5555FF'><</font>block,N,bn_con,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> ares <font color='#5555FF'>=</font> relu<font color='#5555FF'><</font>residual<font color='#5555FF'><</font>block,N,affine,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> res_down <font color='#5555FF'>=</font> relu<font color='#5555FF'><</font>residual_down<font color='#5555FF'><</font>block,N,bn_con,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'><u>int</u></font> N, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> ares_down <font color='#5555FF'>=</font> relu<font color='#5555FF'><</font>residual_down<font color='#5555FF'><</font>block,N,affine,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Stacks of residual units grouped by filter count: level1 uses 512 filters, level2 256,
</font><font color='#009900'>// level3 128 and level4 64.  In levels 1 to 3 the unit sitting directly on SUBNET is a
</font><font color='#009900'>// res_down unit followed by plain res units; level4 consists of plain res units only.
</font><font color='#009900'>// The alevel aliases mirror the level aliases using the ares and ares_down units.
</font><font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> level1 <font color='#5555FF'>=</font> res<font color='#5555FF'><</font><font color='#979000'>512</font>,res<font color='#5555FF'><</font><font color='#979000'>512</font>,res_down<font color='#5555FF'><</font><font color='#979000'>512</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> level2 <font color='#5555FF'>=</font> res<font color='#5555FF'><</font><font color='#979000'>256</font>,res<font color='#5555FF'><</font><font color='#979000'>256</font>,res<font color='#5555FF'><</font><font color='#979000'>256</font>,res<font color='#5555FF'><</font><font color='#979000'>256</font>,res<font color='#5555FF'><</font><font color='#979000'>256</font>,res_down<font color='#5555FF'><</font><font color='#979000'>256</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> level3 <font color='#5555FF'>=</font> res<font color='#5555FF'><</font><font color='#979000'>128</font>,res<font color='#5555FF'><</font><font color='#979000'>128</font>,res<font color='#5555FF'><</font><font color='#979000'>128</font>,res_down<font color='#5555FF'><</font><font color='#979000'>128</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> level4 <font color='#5555FF'>=</font> res<font color='#5555FF'><</font><font color='#979000'>64</font>,res<font color='#5555FF'><</font><font color='#979000'>64</font>,res<font color='#5555FF'><</font><font color='#979000'>64</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> alevel1 <font color='#5555FF'>=</font> ares<font color='#5555FF'><</font><font color='#979000'>512</font>,ares<font color='#5555FF'><</font><font color='#979000'>512</font>,ares_down<font color='#5555FF'><</font><font color='#979000'>512</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> alevel2 <font color='#5555FF'>=</font> ares<font color='#5555FF'><</font><font color='#979000'>256</font>,ares<font color='#5555FF'><</font><font color='#979000'>256</font>,ares<font color='#5555FF'><</font><font color='#979000'>256</font>,ares<font color='#5555FF'><</font><font color='#979000'>256</font>,ares<font color='#5555FF'><</font><font color='#979000'>256</font>,ares_down<font color='#5555FF'><</font><font color='#979000'>256</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> alevel3 <font color='#5555FF'>=</font> ares<font color='#5555FF'><</font><font color='#979000'>128</font>,ares<font color='#5555FF'><</font><font color='#979000'>128</font>,ares<font color='#5555FF'><</font><font color='#979000'>128</font>,ares_down<font color='#5555FF'><</font><font color='#979000'>128</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> <font color='#0000FF'>using</font> alevel4 <font color='#5555FF'>=</font> ares<font color='#5555FF'><</font><font color='#979000'>64</font>,ares<font color='#5555FF'><</font><font color='#979000'>64</font>,ares<font color='#5555FF'><</font><font color='#979000'>64</font>,SUBNET<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#009900'>// training network type
</font><font color='#009900'>// Reading from the input outwards: a 227x227 rgb input, a 7x7 convolution with 64
</font><font color='#009900'>// filters and stride 2, bn_con, relu, a 3x3 max pool with stride 2, then level4, level3,
</font><font color='#009900'>// level2 and level1, avg_pool_everything, a 1000 output fc layer and finally the
</font><font color='#009900'>// loss_multiclass_log loss.
</font><font color='#0000FF'>using</font> net_type <font color='#5555FF'>=</font> loss_multiclass_log<font color='#5555FF'><</font>fc<font color='#5555FF'><</font><font color='#979000'>1000</font>,avg_pool_everything<font color='#5555FF'><</font>
level1<font color='#5555FF'><</font>
level2<font color='#5555FF'><</font>
level3<font color='#5555FF'><</font>
level4<font color='#5555FF'><</font>
max_pool<font color='#5555FF'><</font><font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,relu<font color='#5555FF'><</font>bn_con<font color='#5555FF'><</font>con<font color='#5555FF'><</font><font color='#979000'>64</font>,<font color='#979000'>7</font>,<font color='#979000'>7</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,
input_rgb_image_sized<font color='#5555FF'><</font><font color='#979000'>227</font><font color='#5555FF'>></font>
<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#009900'>// testing network type (replaced batch normalization with fixed affine transforms)
</font><font color='#009900'>// The layer sequence matches net_type one for one: every bn_con is swapped for affine
</font><font color='#009900'>// and the level aliases are swapped for their alevel counterparts.
</font><font color='#0000FF'>using</font> anet_type <font color='#5555FF'>=</font> loss_multiclass_log<font color='#5555FF'><</font>fc<font color='#5555FF'><</font><font color='#979000'>1000</font>,avg_pool_everything<font color='#5555FF'><</font>
alevel1<font color='#5555FF'><</font>
alevel2<font color='#5555FF'><</font>
alevel3<font color='#5555FF'><</font>
alevel4<font color='#5555FF'><</font>
max_pool<font color='#5555FF'><</font><font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,relu<font color='#5555FF'><</font>affine<font color='#5555FF'><</font>con<font color='#5555FF'><</font><font color='#979000'>64</font>,<font color='#979000'>7</font>,<font color='#979000'>7</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,
input_rgb_image_sized<font color='#5555FF'><</font><font color='#979000'>227</font><font color='#5555FF'>></font>
<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>;
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Returns a randomly sized, randomly placed square cropping rectangle inside img.
</font><font color='#009900'>// The side length is a random fraction (between mins and maxs) of the shorter image
</font><font color='#009900'>// dimension, and the box is then shifted by a random offset that keeps it within the
</font><font color='#009900'>// image.  An image with no rows or no columns yields an empty rectangle.
</font><font color='#009900'>// rnd supplies all randomness and is advanced by the call.
</font>rectangle <b><a name='make_random_cropping_rect_resnet'></a>make_random_cropping_rect_resnet</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> img,
    dlib::rand<font color='#5555FF'>&</font> rnd
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// With zero rows or columns the offset computation below would take a remainder
</font>    <font color='#009900'>// modulo zero, so return an empty rectangle before drawing any random numbers.
</font>    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
        <font color='#0000FF'>return</font> <font color='#BB00BB'>rectangle</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// figure out what rectangle we want to crop from the image
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> mins <font color='#5555FF'>=</font> <font color='#979000'>0.466666666</font>, maxs <font color='#5555FF'>=</font> <font color='#979000'>0.875</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> scale <font color='#5555FF'>=</font> mins <font color='#5555FF'>+</font> rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font><font face='Lucida Console'>(</font>maxs<font color='#5555FF'>-</font>mins<font face='Lucida Console'>)</font>;
    <font color='#009900'>// The side length is truncated to a whole number of pixels.  Since scale is below 1
</font>    <font color='#009900'>// the side is always shorter than both image dimensions.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> size <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font>scale<font color='#5555FF'>*</font>std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    rectangle <font color='#BB00BB'>rect</font><font face='Lucida Console'>(</font>size, size<font face='Lucida Console'>)</font>;
    <font color='#009900'>// randomly shift the box around
</font>    point <font color='#BB00BB'>offset</font><font face='Lucida Console'>(</font>rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,
                 rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> <font color='#BB00BB'>move_rect</font><font face='Lucida Console'>(</font>rect, offset<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Produces one augmented training sample from img and stores it in crop: a random
</font><font color='#009900'>// square region resampled to 227x227, mirrored left to right about half of the time,
</font><font color='#009900'>// and finally given a random color offset.  rnd drives every random choice.
</font><font color='#0000FF'><u>void</u></font> <b><a name='randomly_crop_image'></a>randomly_crop_image</b> <font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> img,
    matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> crop,
    dlib::rand<font color='#5555FF'>&</font> rnd
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Choose the source region, then resample it into a 227x227 chip.
</font>    <font color='#0000FF'>const</font> rectangle crop_region <font color='#5555FF'>=</font> <font color='#BB00BB'>make_random_cropping_rect_resnet</font><font face='Lucida Console'>(</font>img, rnd<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> chip_details <font color='#BB00BB'>chip</font><font face='Lucida Console'>(</font>crop_region, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>227</font>,<font color='#979000'>227</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>img, chip, crop<font face='Lucida Console'>)</font>;

    <font color='#009900'>// Mirror the chip when the random draw lands above one half.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> mirror <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0.5</font>;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>mirror<font face='Lucida Console'>)</font>
    <b>{</b>
        crop <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>crop<font face='Lucida Console'>)</font>;
    <b>}</b>

    <font color='#009900'>// Color jitter is applied last.
</font>    <font color='#BB00BB'>apply_random_color_offset</font><font face='Lucida Console'>(</font>crop, rnd<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Produces num_crops augmented samples from img and stores them in crops.
</font><font color='#009900'>// Every sample is a random square region resampled to 227x227, mirrored left to right
</font><font color='#009900'>// about half of the time, and given a random color offset.  A num_crops of zero or
</font><font color='#009900'>// less requests no chips.  rnd drives every random choice.
</font><font color='#0000FF'><u>void</u></font> <b><a name='randomly_crop_images'></a>randomly_crop_images</b> <font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> img,
    dlib::array<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>&</font> crops,
    dlib::rand<font color='#5555FF'>&</font> rnd,
    <font color='#0000FF'><u>long</u></font> num_crops
<font face='Lucida Console'>)</font>
<b>{</b>
    std::vector<font color='#5555FF'><</font>chip_details<font color='#5555FF'>></font> dets;
    <font color='#009900'>// The final size is known up front, so allocate once.  The guard keeps a negative
</font>    <font color='#009900'>// count from being converted into a huge unsigned reservation.
</font>    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>num_crops <font color='#5555FF'>></font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
        dets.<font color='#BB00BB'>reserve</font><font face='Lucida Console'>(</font><font color='#0000FF'>static_cast</font><font color='#5555FF'><</font>std::<font color='#0000FF'><u>size_t</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font>num_crops<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> num_crops; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> rect <font color='#5555FF'>=</font> <font color='#BB00BB'>make_random_cropping_rect_resnet</font><font face='Lucida Console'>(</font>img, rnd<font face='Lucida Console'>)</font>;
        dets.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#BB00BB'>chip_details</font><font face='Lucida Console'>(</font>rect, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>227</font>,<font color='#979000'>227</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <b>}</b>
    <font color='#009900'>// Extract all chips in a single call.
</font>    <font color='#BB00BB'>extract_image_chips</font><font face='Lucida Console'>(</font>img, dets, crops<font face='Lucida Console'>)</font>;
    <font color='#009900'>// The loop variable is named crop so it does not shadow the img parameter.
</font>    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&</font><font color='#5555FF'>&</font> crop : crops<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// Also randomly flip the image
</font>        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0.5</font><font face='Lucida Console'>)</font>
            crop <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>crop<font face='Lucida Console'>)</font>;
        <font color='#009900'>// And then randomly adjust the colors.
</font>        <font color='#BB00BB'>apply_random_color_offset</font><font face='Lucida Console'>(</font>crop, rnd<font face='Lucida Console'>)</font>;
    <b>}</b>
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Describes one image of the dataset listing.
</font><font color='#0000FF'>struct</font> <b><a name='image_info'></a>image_info</b>
<b>{</b>
    <font color='#009900'>// Name of the image file, as stored by the listing functions.
</font>    std::string filename;
    <font color='#009900'>// Textual class label of the image.
</font>    std::string label;
    <font color='#009900'>// Integer id standing in for label.  It has no default value, so it must be set
</font>    <font color='#009900'>// before being read.
</font>    <font color='#0000FF'><u>long</u></font> numeric_label;
<b>}</b>;
<font color='#009900'>// Builds the training listing from images_folder, which is expected to contain one
</font><font color='#009900'>// sub directory per class.  Every file found directly inside a sub directory becomes
</font><font color='#009900'>// one image_info whose label is the sub directory name.  Numeric labels start at 0 and
</font><font color='#009900'>// follow the sorted order of the sub directories, so they are stable between runs.
</font>std::vector<font color='#5555FF'><</font>image_info<font color='#5555FF'>></font> <b><a name='get_imagenet_train_listing'></a>get_imagenet_train_listing</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&</font> images_folder
<font face='Lucida Console'>)</font>
<b>{</b>
    std::vector<font color='#5555FF'><</font>image_info<font color='#5555FF'>></font> results;
    image_info temp;
    temp.numeric_label <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#009900'>// We will loop over all the label types in the dataset, each is contained in a subfolder.
</font>    <font color='#0000FF'>auto</font> subdirs <font color='#5555FF'>=</font> <font color='#BB00BB'>directory</font><font face='Lucida Console'>(</font>images_folder<font face='Lucida Console'>)</font>.<font color='#BB00BB'>get_dirs</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// But first, sort the sub directories so the numeric labels will be assigned in sorted order.
</font>    std::<font color='#BB00BB'>sort</font><font face='Lucida Console'>(</font>subdirs.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, subdirs.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// Iterate by const reference so no directory or file object is copied per iteration.
</font>    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> subdir : subdirs<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// Now get all the images in this label type
</font>        temp.label <font color='#5555FF'>=</font> subdir.<font color='#BB00BB'>name</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> image_file : subdir.<font color='#BB00BB'>get_files</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <b>{</b>
            temp.filename <font color='#5555FF'>=</font> image_file;
            results.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
        <b>}</b>
        <font color='#009900'>// The next sub directory is the next class.
</font>        <font color='#5555FF'>+</font><font color='#5555FF'>+</font>temp.numeric_label;
    <b>}</b>
    <font color='#0000FF'>return</font> results;
<b>}</b>
<font color='#009900'>// Builds the validation listing.  validation_images_file is a text file holding
</font><font color='#009900'>// whitespace separated pairs of label and file name, and each file name is taken
</font><font color='#009900'>// relative to imagenet_root_dir.  Numeric labels start at 0 and grow by one each time
</font><font color='#009900'>// the label differs from the one on the previous entry, so entries sharing a label
</font><font color='#009900'>// must be contiguous in the file.
</font><font color='#009900'>// Prints a message to cerr and exits with status 1 when the listing file cannot be
</font><font color='#009900'>// opened or when a listed image file does not exist.
</font>std::vector<font color='#5555FF'><</font>image_info<font color='#5555FF'>></font> <b><a name='get_imagenet_val_listing'></a>get_imagenet_val_listing</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&</font> imagenet_root_dir,
    <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&</font> validation_images_file
<font face='Lucida Console'>)</font>
<b>{</b>
    ifstream <font color='#BB00BB'>fin</font><font face='Lucida Console'>(</font>validation_images_file<font face='Lucida Console'>)</font>;
    <font color='#009900'>// Without this check an unreadable listing would silently give an empty result.
</font>    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>fin<font face='Lucida Console'>)</font>
    <b>{</b>
        cerr <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>unable to open validation images file! </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> validation_images_file <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
        <font color='#BB00BB'>exit</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
    <b>}</b>
    string label, filename;
    std::vector<font color='#5555FF'><</font>image_info<font color='#5555FF'>></font> results;
    image_info temp;
    <font color='#009900'>// Start at -1 so that the first label read is bumped to 0.
</font>    temp.numeric_label <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>1</font>;
    <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>fin <font color='#5555FF'>></font><font color='#5555FF'>></font> label <font color='#5555FF'>></font><font color='#5555FF'>></font> filename<font face='Lucida Console'>)</font>
    <b>{</b>
        temp.filename <font color='#5555FF'>=</font> imagenet_root_dir<font color='#5555FF'>+</font>"<font color='#CC0000'>/</font>"<font color='#5555FF'>+</font>filename;
        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font><font color='#BB00BB'>file_exists</font><font face='Lucida Console'>(</font>temp.filename<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <b>{</b>
            cerr <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>file doesn't exist! </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> temp.filename <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
            <font color='#BB00BB'>exit</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
        <b>}</b>
        <font color='#009900'>// A change of label marks the start of the next class.
</font>        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>label <font color='#5555FF'>!</font><font color='#5555FF'>=</font> temp.label<font face='Lucida Console'>)</font>
            <font color='#5555FF'>+</font><font color='#5555FF'>+</font>temp.numeric_label;
        temp.label <font color='#5555FF'>=</font> label;
        results.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
    <b>}</b>
    <font color='#0000FF'>return</font> results;
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font><font color='#5555FF'>*</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>3</font><font face='Lucida Console'>)</font>
<b>{</b>
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>To run this program you need a copy of the imagenet ILSVRC2015 dataset and</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>also the file http://dlib.net/files/imagenet2015_validation_images.txt.bz2</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>With those things, you call this program like this: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>./dnn_imagenet_train_ex /path/to/ILSVRC2015 imagenet2015_validation_images.txt</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<font color='#0000FF'>return</font> <font color='#979000'>1</font>;
<b>}</b>
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\nSCANNING IMAGENET DATASET\n</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<font color='#0000FF'>auto</font> listing <font color='#5555FF'>=</font> <font color='#BB00BB'>get_imagenet_train_listing</font><font face='Lucida Console'>(</font><font color='#BB00BB'>string</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font><font color='#5555FF'>+</font>"<font color='#CC0000'>/Data/CLS-LOC/train/</font>"<font face='Lucida Console'>)</font>;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>images in dataset: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> number_of_classes <font color='#5555FF'>=</font> listing.<font color='#BB00BB'>back</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.numeric_label<font color='#5555FF'>+</font><font color='#979000'>1</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> number_of_classes <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>1000</font><font face='Lucida Console'>)</font>
<b>{</b>
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Didn't find the imagenet dataset. </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<font color='#0000FF'>return</font> <font color='#979000'>1</font>;
<b>}</b>
<font color='#BB00BB'>set_dnn_prefer_smallest_algorithms</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> initial_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> weight_decay <font color='#5555FF'>=</font> <font color='#979000'>0.0001</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>;
net_type net;
dnn_trainer<font color='#5555FF'><</font>net_type<font color='#5555FF'>></font> <font color='#BB00BB'>trainer</font><font face='Lucida Console'>(</font>net,<font color='#BB00BB'>sgd</font><font face='Lucida Console'>(</font>weight_decay, momentum<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font>initial_learning_rate<font face='Lucida Console'>)</font>;
trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>imagenet_trainer_state_file.dat</font>", std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// This threshold is probably excessively large. You could likely get good results
</font> <font color='#009900'>// with a smaller value but if you aren't in a hurry this value will surely work well.
</font> trainer.<font color='#BB00BB'>set_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font color='#979000'>20000</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// Since the progress threshold is so large might as well set the batch normalization
</font> <font color='#009900'>// stats window to something big too.
</font> <font color='#BB00BB'>set_all_bn_running_stats_window_sizes</font><font face='Lucida Console'>(</font>net, <font color='#979000'>1000</font><font face='Lucida Console'>)</font>;
std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> samples;
std::vector<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font> labels;
<font color='#009900'>// Start a bunch of threads that read images from disk and pull out random crops. It's
</font> <font color='#009900'>// important to be sure to feed the GPU fast enough to keep it busy. Using multiple
</font> <font color='#009900'>// threads for this kind of data preparation helps us do that. Each thread puts the
</font> <font color='#009900'>// crops into the data queue.
</font> dlib::pipe<font color='#5555FF'><</font>std::pair<font color='#5555FF'><</font>image_info,matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font> <font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>auto</font> f <font color='#5555FF'>=</font> [<font color='#5555FF'>&</font>data, <font color='#5555FF'>&</font>listing]<font face='Lucida Console'>(</font>time_t seed<font face='Lucida Console'>)</font>
<b>{</b>
dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>seed<font face='Lucida Console'>)</font>;
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> img;
std::pair<font color='#5555FF'><</font>image_info, matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> temp;
<font color='#0000FF'>while</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>is_enabled</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<b>{</b>
temp.first <font color='#5555FF'>=</font> listing[rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>];
<font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>img, temp.first.filename<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>randomly_crop_image</font><font face='Lucida Console'>(</font>img, temp.second, rnd<font face='Lucida Console'>)</font>;
data.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>;
std::thread <font color='#BB00BB'>data_loader1</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
std::thread <font color='#BB00BB'>data_loader2</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
std::thread <font color='#BB00BB'>data_loader3</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
std::thread <font color='#BB00BB'>data_loader4</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
<font color='#009900'>// The main training loop. Keep making mini-batches and giving them to the trainer.
</font> <font color='#009900'>// We will run until the learning rate has dropped by a factor of 1e-3.
</font> <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font><font color='#5555FF'>=</font> initial_learning_rate<font color='#5555FF'>*</font><font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>
<b>{</b>
samples.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
labels.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// make a 160 image mini-batch
</font> std::pair<font color='#5555FF'><</font>image_info, matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> img;
<font color='#0000FF'>while</font><font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> <font color='#979000'>160</font><font face='Lucida Console'>)</font>
<b>{</b>
data.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>img<font face='Lucida Console'>)</font>;
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>img.second<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>img.first.numeric_label<font face='Lucida Console'>)</font>;
<b>}</b>
trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Training done, tell threads to stop and make sure to wait for them to finish before
</font> <font color='#009900'>// moving on.
</font> data.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
data_loader1.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
data_loader2.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
data_loader3.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
data_loader4.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// also wait for threaded processing to stop in the trainer.
</font> trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>saving network</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>resnet34.dnn</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> net;
<font color='#009900'>// Now test the network on the imagenet validation dataset. First, make a testing
</font> <font color='#009900'>// network with softmax as the final layer. We don't have to do this if we just wanted
</font> <font color='#009900'>// to test the "top1 accuracy" since the normal network outputs the class prediction.
</font> <font color='#009900'>// But this snet object will make getting the top5 predictions easy as it directly
</font> <font color='#009900'>// outputs the probability of each class as its final output.
</font> softmax<font color='#5555FF'><</font>anet_type::subnet_type<font color='#5555FF'>></font> snet; snet.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Testing network on imagenet validation dataset...</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<font color='#0000FF'><u>int</u></font> num_right <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'><u>int</u></font> num_wrong <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'><u>int</u></font> num_right_top1 <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'><u>int</u></font> num_wrong_top1 <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// loop over all the imagenet validation images
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font> l : <font color='#BB00BB'>get_imagenet_val_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>], argv[<font color='#979000'>2</font>]<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<b>{</b>
dlib::array<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> images;
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> img;
<font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>img, l.filename<font face='Lucida Console'>)</font>;
<font color='#009900'>// Grab 16 random crops from the image. We will run all of them through the
</font> <font color='#009900'>// network and average the results.
</font> <font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> num_crops <font color='#5555FF'>=</font> <font color='#979000'>16</font>;
<font color='#BB00BB'>randomly_crop_images</font><font face='Lucida Console'>(</font>img, images, rnd, num_crops<font face='Lucida Console'>)</font>;
<font color='#009900'>// p(i) == the probability the image contains object of class i.
</font> matrix<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font>,<font color='#979000'>1</font>,<font color='#979000'>1000</font><font color='#5555FF'>></font> p <font color='#5555FF'>=</font> <font color='#BB00BB'>sum_rows</font><font face='Lucida Console'>(</font><font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font><font color='#BB00BB'>snet</font><font face='Lucida Console'>(</font>images.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, images.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>num_crops;
<font color='#009900'>// check top 1 accuracy
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>index_of_max</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> l.numeric_label<font face='Lucida Console'>)</font>
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right_top1;
<font color='#0000FF'>else</font>
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong_top1;
<font color='#009900'>// check top 5 accuracy
</font> <font color='#0000FF'><u>bool</u></font> found_match <font color='#5555FF'>=</font> <font color='#979000'>false</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> k <font color='#5555FF'>=</font> <font color='#979000'>0</font>; k <font color='#5555FF'><</font> <font color='#979000'>5</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>k<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'><u>long</u></font> predicted_label <font color='#5555FF'>=</font> <font color='#BB00BB'>index_of_max</font><font face='Lucida Console'>(</font>p<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>p</font><font face='Lucida Console'>(</font>predicted_label<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>predicted_label <font color='#5555FF'>=</font><font color='#5555FF'>=</font> l.numeric_label<font face='Lucida Console'>)</font>
<b>{</b>
found_match <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
<font color='#0000FF'>break</font>;
<b>}</b>
<b>}</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>found_match<font face='Lucida Console'>)</font>
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right;
<font color='#0000FF'>else</font>
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong;
<b>}</b>
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>val top5 accuracy: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> num_right<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>num_right<font color='#5555FF'>+</font>num_wrong<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>val top1 accuracy: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> num_right_top1<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font>num_right_top1<font color='#5555FF'>+</font>num_wrong_top1<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<b>}</b>
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&</font> e<font face='Lucida Console'>)</font>
<b>{</b>
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<b>}</b>
</pre></body></html> |