|
<!doctype html>
<html lang="en"><head><meta charset="utf-8"><title>dlib C++ Library - dnn_inception_ex.cpp</title></head><body style="background-color: white"><pre>
|
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt |
|
</font><font color='#009900'>/* |
|
This is an example illustrating the use of the deep learning tools from the |
|
dlib C++ Library. I'm assuming you have already read the introductory |
|
<a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a> and <a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a> examples. In this |
|
example we are going to show how to create inception networks. |
|
|
|
An inception network is composed of inception blocks of the form: |
|
|
|
                input from SUBNET
                  /     |     \
                 /      |      \
           block1    block2  ...  blockN
                 \      |      /
                  \     |     /
         concatenate tensors from blocks
                        |
                     output
|
|
|
That is, an inception block runs a number of smaller networks (e.g. block1, |
|
block2) and then concatenates their results. For further reading refer to: |
|
Szegedy, Christian, et al. "Going deeper with convolutions." Proceedings of |
|
the IEEE Conference on Computer Vision and Pattern Recognition. 2015. |
|
*/</font> |
|
|
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>dnn.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>></font> |
|
|
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std; |
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib; |
|
|
|
<font color='#009900'>// An inception layer runs several small sub-networks ("blocks") on the same input
// and concatenates their outputs. Below are the four blocks of our first inception
// layer. Each block ends with a 10-filter convolution followed by a relu; the blocks
// differ in that convolution's kernel size and in what feeds it:
//   block_a1: 1x1 convolution applied directly to the input
//   block_a2: 3x3 convolution applied to a 1x1, 16-filter convolution + relu
//   block_a3: 5x5 convolution applied to a 1x1, 16-filter convolution + relu
//   block_a4: 1x1 convolution applied to a 3x3, stride-1 max pooling of the input
// pre1x1_a names the "1x1, 16-filter convolution + relu" stage used by block_a2 and
// block_a3. It is only a type alias, so each block still gets its own layer instance.</font>
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> pre1x1_a <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>16</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>, SUBNET<font color='#5555FF'>&gt;&gt;</font>;

<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> block_a1 <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>10</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>, SUBNET<font color='#5555FF'>&gt;&gt;</font>;

<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> block_a2 <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>10</font>,<font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>, pre1x1_a<font color='#5555FF'>&lt;</font>SUBNET<font color='#5555FF'>&gt;&gt;&gt;</font>;

<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> block_a3 <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>10</font>,<font color='#979000'>5</font>,<font color='#979000'>5</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>, pre1x1_a<font color='#5555FF'>&lt;</font>SUBNET<font color='#5555FF'>&gt;&gt;&gt;</font>;

<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> block_a4 <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>10</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>, max_pool<font color='#5555FF'>&lt;</font><font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>, SUBNET<font color='#5555FF'>&gt;&gt;&gt;</font>;
|
|
|
<font color='#009900'>// Our first inception layer: block_a1 through block_a4 all read the same input and
// their outputs are concatenated into a single output. inception4 is one of several
// inceptionN templates that dlib provides (inception3 is used further below), and
// they are themselves built from concat layers.</font>
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> incept_a <font color='#5555FF'>=</font> inception4<font color='#5555FF'>&lt;</font>block_a1, block_a2, block_a3, block_a4, SUBNET<font color='#5555FF'>&gt;</font>;
|
|
|
<font color='#009900'>// Different inception layers in the same network do not need the same structure.
// The one requirement is that, inside any single inception layer, every block
// outputs tensors with the same number of rows and columns, because those outputs
// get concatenated. Our second inception layer uses three blocks, each ending in a
// 4-filter convolution followed by a relu:
//   block_b1: 1x1 convolution applied directly to the input
//   block_b2: 3x3 convolution applied directly to the input
//   block_b3: 1x1 convolution applied to a 3x3, stride-1 max pooling of the input</font>
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> block_b1 <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>4</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>, SUBNET<font color='#5555FF'>&gt;&gt;</font>;

<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> block_b2 <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>4</font>,<font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>, SUBNET<font color='#5555FF'>&gt;&gt;</font>;

<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> block_b3 <font color='#5555FF'>=</font> relu<font color='#5555FF'>&lt;</font>con<font color='#5555FF'>&lt;</font><font color='#979000'>4</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>, max_pool<font color='#5555FF'>&lt;</font><font color='#979000'>3</font>,<font color='#979000'>3</font>,<font color='#979000'>1</font>,<font color='#979000'>1</font>, SUBNET<font color='#5555FF'>&gt;&gt;&gt;</font>;

<font color='#009900'>// The second inception layer concatenates the outputs of the three blocks above.</font>
<font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> incept_b <font color='#5555FF'>=</font> inception3<font color='#5555FF'>&lt;</font>block_b1, block_b2, block_b3, SUBNET<font color='#5555FF'>&gt;</font>;
|
|
|
<font color='#009900'>// A small network for classifying MNIST digits, built from the two inception layers
// above. It is trained and evaluated in main() below. As usual with dlib, layers are
// listed from the loss layer (outermost) down to the input layer (innermost), so data
// flows from the bottom line of this definition up to the top.</font>
<font color='#0000FF'>using</font> net_type <font color='#5555FF'>=</font> loss_multiclass_log<font color='#5555FF'>&lt;</font>   <font color='#009900'>// loss layer: maps the outputs to labels</font>
    fc<font color='#5555FF'>&lt;</font><font color='#979000'>10</font>,                              <font color='#009900'>// one output per digit class</font>
    relu<font color='#5555FF'>&lt;</font>fc<font color='#5555FF'>&lt;</font><font color='#979000'>32</font>,                         <font color='#009900'>// 32-unit fully connected layer, then relu</font>
    max_pool<font color='#5555FF'>&lt;</font><font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>, incept_b<font color='#5555FF'>&lt;</font>         <font color='#009900'>// second inception layer, then 2x2 max pooling</font>
    max_pool<font color='#5555FF'>&lt;</font><font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>,<font color='#979000'>2</font>, incept_a<font color='#5555FF'>&lt;</font>         <font color='#009900'>// first inception layer, then 2x2 max pooling</font>
    input<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;&gt;</font>        <font color='#009900'>// input layer: one image per sample</font>
    <font color='#5555FF'>&gt;&gt;&gt;&gt;&gt;&gt;&gt;&gt;</font>;
|
|
|
<font color='#009900'>// Loads MNIST from the folder given as the only command line argument, trains
// net_type on the training images, saves the trained network to
// mnist_network_inception.dat, and then prints how well it classifies both the
// training and the testing images.
//
// Returns 1 if the command line is wrong or if a std::exception is thrown while
// running, and 0 otherwise.</font>
<font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>**</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font>
<b>{</b>
    <font color='#009900'>// This example is going to run on the MNIST dataset.</font>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'>!=</font> <font color='#979000'>2</font><font face='Lucida Console'>)</font>
    <b>{</b>
        cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>This example needs the MNIST dataset to run!</font>" <font color='#5555FF'>&lt;&lt;</font> endl;
        cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>You can get MNIST from http://yann.lecun.com/exdb/mnist/</font>" <font color='#5555FF'>&lt;&lt;</font> endl;
        cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>Download the 4 files that comprise the dataset, decompress them, and</font>" <font color='#5555FF'>&lt;&lt;</font> endl;
        cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>put them in a folder. Then give that folder as input to this program.</font>" <font color='#5555FF'>&lt;&lt;</font> endl;
        <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
    <b>}</b>

    <font color='#009900'>// Load the training and testing images and their labels from the folder in argv[1].</font>
    std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;&gt;</font> training_images;
    std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> training_labels;
    std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;&gt;</font> testing_images;
    std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> testing_labels;
    <font color='#BB00BB'>load_mnist_dataset</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>], training_images, training_labels, testing_images, testing_labels<font face='Lucida Console'>)</font>;

    <font color='#009900'>// Make an instance of our inception network and print its structure.</font>
    net_type net;
    cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>The net has </font>" <font color='#5555FF'>&lt;&lt;</font> net.num_layers <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'> layers in it.</font>" <font color='#5555FF'>&lt;&lt;</font> endl;
    cout <font color='#5555FF'>&lt;&lt;</font> net <font color='#5555FF'>&lt;&lt;</font> endl;

    cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>Training NN...</font>" <font color='#5555FF'>&lt;&lt;</font> endl;
    dnn_trainer<font color='#5555FF'>&lt;</font>net_type<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>trainer</font><font face='Lucida Console'>(</font>net<font face='Lucida Console'>)</font>;
    trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font><font color='#979000'>0.01</font><font face='Lucida Console'>)</font>;
    trainer.<font color='#BB00BB'>set_min_learning_rate</font><font face='Lucida Console'>(</font><font color='#979000'>0.00001</font><font face='Lucida Console'>)</font>;
    trainer.<font color='#BB00BB'>set_mini_batch_size</font><font face='Lucida Console'>(</font><font color='#979000'>128</font><font face='Lucida Console'>)</font>;
    trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>()</font>;
    <font color='#009900'>// Save the trainer's state to the file inception_sync every 20 seconds.</font>
    trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>inception_sync</font>", std::chrono::<font color='#BB00BB'>seconds</font><font face='Lucida Console'>(</font><font color='#979000'>20</font><font face='Lucida Console'>))</font>;
    <font color='#009900'>// Train the network. This might take a few minutes...</font>
    trainer.<font color='#BB00BB'>train</font><font face='Lucida Console'>(</font>training_images, training_labels<font face='Lucida Console'>)</font>;

    <font color='#009900'>// At this point our net object should have learned how to classify MNIST images. But
    // before we try it out let's save it to disk. Note that, since the trainer has been
    // running images through the network, net will have a bunch of state in it related to
    // the last batch of images it processed (e.g. outputs from each layer). Since we
    // don't care about saving that kind of stuff to disk we can tell the network to forget
    // about that kind of transient data so that our file will be smaller. We do this by
    // "cleaning" the network before saving it.</font>
    net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>()</font>;
    <font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>mnist_network_inception.dat</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;&lt;</font> net;
    <font color='#009900'>// Now if we later wanted to recall the network from disk we can simply say:
    // deserialize("mnist_network_inception.dat") &gt;&gt; net;</font>

    <font color='#009900'>// Now let's run the training images through the network. This statement runs all the
    // images through it and asks the loss layer to convert the network's raw output into
    // labels. In our case, these labels are the numbers between 0 and 9.</font>
    std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> predicted_labels <font color='#5555FF'>=</font> <font color='#BB00BB'>net</font><font face='Lucida Console'>(</font>training_images<font face='Lucida Console'>)</font>;
    <font color='#0000FF'><u>int</u></font> num_right <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#0000FF'><u>int</u></font> num_wrong <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#009900'>// And then let's see if it classified them correctly.</font>
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> training_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>()</font>; <font color='#5555FF'>++</font>i<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>predicted_labels[i] <font color='#5555FF'>==</font> training_labels[i]<font face='Lucida Console'>)</font>
            <font color='#5555FF'>++</font>num_right;
        <font color='#0000FF'>else</font>
            <font color='#5555FF'>++</font>num_wrong;
    <b>}</b>
    cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>training num_right: </font>" <font color='#5555FF'>&lt;&lt;</font> num_right <font color='#5555FF'>&lt;&lt;</font> endl;
    cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>training num_wrong: </font>" <font color='#5555FF'>&lt;&lt;</font> num_wrong <font color='#5555FF'>&lt;&lt;</font> endl;
    cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>training accuracy: </font>" <font color='#5555FF'>&lt;&lt;</font> num_right<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)(</font>num_right<font color='#5555FF'>+</font>num_wrong<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;&lt;</font> endl;

    <font color='#009900'>// Let's also see if the network can correctly classify the testing images.
    // Since MNIST is an easy dataset, we should see 99% accuracy.</font>
    predicted_labels <font color='#5555FF'>=</font> <font color='#BB00BB'>net</font><font face='Lucida Console'>(</font>testing_images<font face='Lucida Console'>)</font>;
    num_right <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    num_wrong <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> testing_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>()</font>; <font color='#5555FF'>++</font>i<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>predicted_labels[i] <font color='#5555FF'>==</font> testing_labels[i]<font face='Lucida Console'>)</font>
            <font color='#5555FF'>++</font>num_right;
        <font color='#0000FF'>else</font>
            <font color='#5555FF'>++</font>num_wrong;
    <b>}</b>
    cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>testing num_right: </font>" <font color='#5555FF'>&lt;&lt;</font> num_right <font color='#5555FF'>&lt;&lt;</font> endl;
    cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>testing num_wrong: </font>" <font color='#5555FF'>&lt;&lt;</font> num_wrong <font color='#5555FF'>&lt;&lt;</font> endl;
    cout <font color='#5555FF'>&lt;&lt;</font> "<font color='#CC0000'>testing accuracy: </font>" <font color='#5555FF'>&lt;&lt;</font> num_right<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)(</font>num_right<font color='#5555FF'>+</font>num_wrong<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;&lt;</font> endl;
<b>}</b>
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&amp;</font> e<font face='Lucida Console'>)</font>
<b>{</b>
    cout <font color='#5555FF'>&lt;&lt;</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>()</font> <font color='#5555FF'>&lt;&lt;</font> endl;
    <font color='#009900'>// Report the failure through the exit status. Without this return, reaching the end
    // of a handler of main's function-try-block is the same as returning 0 from main,
    // so a failed run would look successful to scripts and build tools.</font>
    <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
<b>}</b>
|
|
|
|
|
</pre></body></html> |