|
<html><head><title>dlib C++ Library - dnn_dcgan_train_ex.cpp</title></head><body bgcolor='white'><pre> |
|
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt |
|
</font><font color='#009900'>/* |
|
This is an example illustrating the use of the deep learning tools from the dlib C++ |
|
Library. I'm assuming you have already read the <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a>, the |
|
<a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a> and the <a href="dnn_introduction3_ex.cpp.html">dnn_introduction3_ex.cpp</a> examples. In this example |
|
program we are going to show how one can train Generative Adversarial Networks (GANs). In |
|
particular, we will train a Deep Convolutional Generative Adversarial Network (DCGAN) like |
|
the one introduced in this paper: |
|
"Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks" |
|
by Alec Radford, Luke Metz, Soumith Chintala. |
|
|
|
The main idea is that there are two neural networks training at the same time: |
|
- the generator is in charge of generating images that look as close as possible to the |
|
ones from the dataset. |
|
- the discriminator will decide whether an image is fake (created by the generator) or real |
|
(selected from the dataset). |
|
|
|
Each training iteration alternates between training the discriminator and the generator. |
|
We first train the discriminator with real and fake images and then use the gradient from |
|
the discriminator to update the generator. |
|
|
|
In this example, we are going to learn how to generate digits from the MNIST dataset, but |
|
the same code can be run using the Fashion MNIST dataset: |
|
https://github.com/zalandoresearch/fashion-mnist |
|
*/</font> |
|
|
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>algorithm<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font> |
|
|
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>dnn.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>gui_widgets.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>matrix.h<font color='#5555FF'>></font> |
|
|
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std; |
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib; |
|
|
|
<font color='#009900'>// Some helper definitions for the noise generation |
|
</font><font color='#0000FF'>const</font> <font color='#0000FF'><u>size_t</u></font> noise_size <font color='#5555FF'>=</font> <font color='#979000'>100</font>; |
|
<font color='#0000FF'>using</font> noise_t <font color='#5555FF'>=</font> std::array<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font>, <font color='#979000'>1</font>, <font color='#979000'>1</font><font color='#5555FF'>></font>, noise_size<font color='#5555FF'>></font>; |
|
|
|
noise_t <b><a name='make_noise'></a>make_noise</b><font face='Lucida Console'>(</font>dlib::rand<font color='#5555FF'>&</font> rnd<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
noise_t noise; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&</font> n : noise<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
n <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_gaussian</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>return</font> noise; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// A convolution with custom padding |
|
</font><font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'><u>long</u></font> num_filters, <font color='#0000FF'><u>long</u></font> kernel_size, <font color='#0000FF'><u>int</u></font> stride, <font color='#0000FF'><u>int</u></font> padding, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> |
|
<font color='#0000FF'>using</font> conp <font color='#5555FF'>=</font> add_layer<font color='#5555FF'><</font>con_<font color='#5555FF'><</font>num_filters, kernel_size, kernel_size, stride, stride, padding, padding<font color='#5555FF'>></font>, SUBNET<font color='#5555FF'>></font>; |
|
|
|
<font color='#009900'>// A transposed convolution with custom padding |
|
</font><font color='#0000FF'>template</font><font color='#5555FF'><</font><font color='#0000FF'><u>long</u></font> num_filters, <font color='#0000FF'><u>long</u></font> kernel_size, <font color='#0000FF'><u>int</u></font> stride, <font color='#0000FF'><u>int</u></font> padding, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>></font> |
|
<font color='#0000FF'>using</font> contp <font color='#5555FF'>=</font> add_layer<font color='#5555FF'><</font>cont_<font color='#5555FF'><</font>num_filters, kernel_size, kernel_size, stride, stride, padding, padding<font color='#5555FF'>></font>, SUBNET<font color='#5555FF'>></font>; |
|
|
|
<font color='#009900'>// The generator is made of a bunch of deconvolutional layers. Its input is a 1 x 1 x k noise |
|
</font><font color='#009900'>// tensor, and the output is the generated image. The loss layer does not matter for the |
|
</font><font color='#009900'>// training; we just stack a compatible one on top to be able to have a () operator on the |
|
</font><font color='#009900'>// generator. |
|
</font><font color='#0000FF'>using</font> generator_type <font color='#5555FF'>=</font> |
|
loss_binary_log_per_pixel<font color='#5555FF'><</font> |
|
sig<font color='#5555FF'><</font>contp<font color='#5555FF'><</font><font color='#979000'>1</font>, <font color='#979000'>4</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>, |
|
relu<font color='#5555FF'><</font>bn_con<font color='#5555FF'><</font>contp<font color='#5555FF'><</font><font color='#979000'>64</font>, <font color='#979000'>4</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>, |
|
relu<font color='#5555FF'><</font>bn_con<font color='#5555FF'><</font>contp<font color='#5555FF'><</font><font color='#979000'>128</font>, <font color='#979000'>3</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>, |
|
relu<font color='#5555FF'><</font>bn_con<font color='#5555FF'><</font>contp<font color='#5555FF'><</font><font color='#979000'>256</font>, <font color='#979000'>4</font>, <font color='#979000'>1</font>, <font color='#979000'>0</font>, |
|
input<font color='#5555FF'><</font>noise_t<font color='#5555FF'>></font> |
|
<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
|
|
<font color='#009900'>// Now, let's proceed to define the discriminator, whose role will be to decide whether an |
|
</font><font color='#009900'>// image is fake or not. |
|
</font><font color='#0000FF'>using</font> discriminator_type <font color='#5555FF'>=</font> |
|
loss_binary_log<font color='#5555FF'><</font> |
|
conp<font color='#5555FF'><</font><font color='#979000'>1</font>, <font color='#979000'>3</font>, <font color='#979000'>1</font>, <font color='#979000'>0</font>, |
|
leaky_relu<font color='#5555FF'><</font>bn_con<font color='#5555FF'><</font>conp<font color='#5555FF'><</font><font color='#979000'>256</font>, <font color='#979000'>4</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>, |
|
leaky_relu<font color='#5555FF'><</font>bn_con<font color='#5555FF'><</font>conp<font color='#5555FF'><</font><font color='#979000'>128</font>, <font color='#979000'>4</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>, |
|
leaky_relu<font color='#5555FF'><</font>conp<font color='#5555FF'><</font><font color='#979000'>64</font>, <font color='#979000'>4</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>, |
|
input<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>></font><font color='#5555FF'>></font> |
|
<font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
|
|
<font color='#009900'>// Some helper functions to generate and get the images from the generator |
|
</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>></font> <b><a name='generate_image'></a>generate_image</b><font face='Lucida Console'>(</font>generator_type<font color='#5555FF'>&</font> net, <font color='#0000FF'>const</font> noise_t<font color='#5555FF'>&</font> noise<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>></font> output <font color='#5555FF'>=</font> <font color='#BB00BB'>net</font><font face='Lucida Console'>(</font>noise<font face='Lucida Console'>)</font>; |
|
matrix<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>></font> image; |
|
<font color='#BB00BB'>assign_image</font><font face='Lucida Console'>(</font>image, <font color='#979000'>255</font> <font color='#5555FF'>*</font> output<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>return</font> image; |
|
<b>}</b> |
|
|
|
std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>></font><font color='#5555FF'>></font> <b><a name='get_generated_images'></a>get_generated_images</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>></font><font color='#5555FF'>></font> images; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> n <font color='#5555FF'>=</font> <font color='#979000'>0</font>; n <font color='#5555FF'><</font> out.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>n<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
matrix<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>></font> output <font color='#5555FF'>=</font> <font color='#BB00BB'>image_plane</font><font face='Lucida Console'>(</font>out, n<font face='Lucida Console'>)</font>; |
|
matrix<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>></font> image; |
|
<font color='#BB00BB'>assign_image</font><font face='Lucida Console'>(</font>image, <font color='#979000'>255</font> <font color='#5555FF'>*</font> output<font face='Lucida Console'>)</font>; |
|
images.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>return</font> images; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font><font color='#5555FF'>*</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font> |
|
<b>{</b> |
|
<font color='#009900'>// This example is going to run on the MNIST dataset. |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>2</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>This example needs the MNIST dataset to run!</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>You can get MNIST from http://yann.lecun.com/exdb/mnist/</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Download the 4 files that comprise the dataset, decompress them, and</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>put them in a folder. Then give that folder as input to this program.</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#0000FF'>return</font> EXIT_FAILURE; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// MNIST is broken into two parts, a training set of 60000 images and a test set of 10000 |
|
</font> <font color='#009900'>// images. Each image is labeled so that we know what handwritten digit is depicted. |
|
</font> <font color='#009900'>// These next statements load the dataset into memory. |
|
</font> std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>></font><font color='#5555FF'>></font> training_images; |
|
std::vector<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font> training_labels; |
|
std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>></font><font color='#5555FF'>></font> testing_images; |
|
std::vector<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font> testing_labels; |
|
<font color='#BB00BB'>load_mnist_dataset</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>], training_images, training_labels, testing_images, testing_labels<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Fix the random generator seeds for network initialization and noise |
|
</font> <font color='#BB00BB'>srand</font><font face='Lucida Console'>(</font><font color='#979000'>1234</font><font face='Lucida Console'>)</font>; |
|
dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>rand</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Instantiate both generator and discriminator |
|
</font> generator_type generator; |
|
discriminator_type discriminator; |
|
<font color='#009900'>// Set up all leaky_relu_ layers in the discriminator to have alpha = 0.2 |
|
</font> <font color='#BB00BB'>visit_computational_layers</font><font face='Lucida Console'>(</font>discriminator, []<font face='Lucida Console'>(</font>leaky_relu_<font color='#5555FF'>&</font> l<font face='Lucida Console'>)</font><b>{</b> l <font color='#5555FF'>=</font> <font color='#BB00BB'>leaky_relu_</font><font face='Lucida Console'>(</font><font color='#979000'>0.2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Remove the bias learning from all bn_ inputs in both networks |
|
</font> <font color='#BB00BB'>disable_duplicative_biases</font><font face='Lucida Console'>(</font>generator<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>disable_duplicative_biases</font><font face='Lucida Console'>(</font>discriminator<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Forward random noise so that we see the tensor size at each layer |
|
</font> <font color='#BB00BB'>discriminator</font><font face='Lucida Console'>(</font><font color='#BB00BB'>generate_image</font><font face='Lucida Console'>(</font>generator, <font color='#BB00BB'>make_noise</font><font face='Lucida Console'>(</font>rnd<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>generator (</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>count_parameters</font><font face='Lucida Console'>(</font>generator<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> parameters)</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> generator <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>discriminator (</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>count_parameters</font><font face='Lucida Console'>(</font>discriminator<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> parameters)</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> discriminator <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
<font color='#009900'>// The solvers for the generator and discriminator networks. In this example, we are going to |
|
</font> <font color='#009900'>// train the networks manually, so we don't need to use a dnn_trainer. Note that the |
|
</font> <font color='#009900'>// discriminator could be trained using a dnn_trainer, but not the generator, since its |
|
</font> <font color='#009900'>// training process is a bit particular. |
|
</font> std::vector<font color='#5555FF'><</font>adam<font color='#5555FF'>></font> <font color='#BB00BB'>g_solvers</font><font face='Lucida Console'>(</font>generator.num_computational_layers, <font color='#BB00BB'>adam</font><font face='Lucida Console'>(</font><font color='#979000'>0</font>, <font color='#979000'>0.5</font>, <font color='#979000'>0.999</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
std::vector<font color='#5555FF'><</font>adam<font color='#5555FF'>></font> <font color='#BB00BB'>d_solvers</font><font face='Lucida Console'>(</font>discriminator.num_computational_layers, <font color='#BB00BB'>adam</font><font face='Lucida Console'>(</font><font color='#979000'>0</font>, <font color='#979000'>0.5</font>, <font color='#979000'>0.999</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'><u>double</u></font> learning_rate <font color='#5555FF'>=</font> <font color='#979000'>2e</font><font color='#5555FF'>-</font><font color='#979000'>4</font>; |
|
|
|
<font color='#009900'>// Resume training from last sync file |
|
</font> <font color='#0000FF'><u>size_t</u></font> iteration <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>file_exists</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>dcgan_sync</font>"<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>dcgan_sync</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'>></font><font color='#5555FF'>></font> generator <font color='#5555FF'>></font><font color='#5555FF'>></font> discriminator <font color='#5555FF'>></font><font color='#5555FF'>></font> iteration; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>size_t</u></font> minibatch_size <font color='#5555FF'>=</font> <font color='#979000'>64</font>; |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>></font> <font color='#BB00BB'>real_labels</font><font face='Lucida Console'>(</font>minibatch_size, <font color='#979000'>1</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>></font> <font color='#BB00BB'>fake_labels</font><font face='Lucida Console'>(</font>minibatch_size, <font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; |
|
dlib::image_window win; |
|
resizable_tensor real_samples_tensor, fake_samples_tensor, noises_tensor; |
|
running_stats<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font> g_loss, d_loss; |
|
<font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>iteration <font color='#5555FF'><</font> <font color='#979000'>50000</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Train the discriminator with real images |
|
</font> std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>></font><font color='#5555FF'>></font> real_samples; |
|
<font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>real_samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> minibatch_size<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>auto</font> idx <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>%</font> training_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
real_samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>training_images[idx]<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#009900'>// The following lines are equivalent to calling train_one_step(real_samples, real_labels) |
|
</font> discriminator.<font color='#BB00BB'>to_tensor</font><font face='Lucida Console'>(</font>real_samples.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, real_samples.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, real_samples_tensor<font face='Lucida Console'>)</font>; |
|
d_loss.<font color='#BB00BB'>add</font><font face='Lucida Console'>(</font>discriminator.<font color='#BB00BB'>compute_loss</font><font face='Lucida Console'>(</font>real_samples_tensor, real_labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
discriminator.<font color='#BB00BB'>back_propagate_error</font><font face='Lucida Console'>(</font>real_samples_tensor<font face='Lucida Console'>)</font>; |
|
discriminator.<font color='#BB00BB'>update_parameters</font><font face='Lucida Console'>(</font>d_solvers, learning_rate<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Train the discriminator with fake images |
|
</font> <font color='#009900'>// 1. Generate some random noise |
|
</font> std::vector<font color='#5555FF'><</font>noise_t<font color='#5555FF'>></font> noises; |
|
<font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>noises.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> minibatch_size<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
noises.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#BB00BB'>make_noise</font><font face='Lucida Console'>(</font>rnd<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#009900'>// 2. Convert noises into a tensor |
|
</font> generator.<font color='#BB00BB'>to_tensor</font><font face='Lucida Console'>(</font>noises.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, noises.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, noises_tensor<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// 3. Forward the noise through the network and convert the outputs into images. |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> fake_samples <font color='#5555FF'>=</font> <font color='#BB00BB'>get_generated_images</font><font face='Lucida Console'>(</font>generator.<font color='#BB00BB'>forward</font><font face='Lucida Console'>(</font>noises_tensor<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// 4. Finally train the discriminator. The following lines are equivalent to calling |
|
</font> <font color='#009900'>// train_one_step(fake_samples, fake_labels) |
|
</font> discriminator.<font color='#BB00BB'>to_tensor</font><font face='Lucida Console'>(</font>fake_samples.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, fake_samples.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, fake_samples_tensor<font face='Lucida Console'>)</font>; |
|
d_loss.<font color='#BB00BB'>add</font><font face='Lucida Console'>(</font>discriminator.<font color='#BB00BB'>compute_loss</font><font face='Lucida Console'>(</font>fake_samples_tensor, fake_labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
discriminator.<font color='#BB00BB'>back_propagate_error</font><font face='Lucida Console'>(</font>fake_samples_tensor<font face='Lucida Console'>)</font>; |
|
discriminator.<font color='#BB00BB'>update_parameters</font><font face='Lucida Console'>(</font>d_solvers, learning_rate<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Train the generator |
|
</font> <font color='#009900'>// This part is the essence of the Generative Adversarial Networks. Until now, we have |
|
</font> <font color='#009900'>// just trained a binary classifier that the generator is not aware of. But now, the |
|
</font> <font color='#009900'>// discriminator is going to give feedback to the generator on how it should update |
|
</font> <font color='#009900'>// itself to generate more realistic images. The following lines perform the same |
|
</font> <font color='#009900'>// actions as train_one_step() except for the network update part. They can also be |
|
</font> <font color='#009900'>// seen as test_one_step() plus the error back propagation. |
|
</font> |
|
<font color='#009900'>// Forward the fake samples and compute the loss with real labels |
|
</font> g_loss.<font color='#BB00BB'>add</font><font face='Lucida Console'>(</font>discriminator.<font color='#BB00BB'>compute_loss</font><font face='Lucida Console'>(</font>fake_samples_tensor, real_labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Back propagate the error to fill the final data gradient |
|
</font> discriminator.<font color='#BB00BB'>back_propagate_error</font><font face='Lucida Console'>(</font>fake_samples_tensor<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>// Get the gradient that will tell the generator how to update itself |
|
</font> <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&</font> d_grad <font color='#5555FF'>=</font> discriminator.<font color='#BB00BB'>get_final_data_gradient</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
generator.<font color='#BB00BB'>back_propagate_error</font><font face='Lucida Console'>(</font>noises_tensor, d_grad<font face='Lucida Console'>)</font>; |
|
generator.<font color='#BB00BB'>update_parameters</font><font face='Lucida Console'>(</font>g_solvers, learning_rate<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// At some point, we should see that the generated images start looking like samples from |
|
</font> <font color='#009900'>// the MNIST dataset |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>+</font><font color='#5555FF'>+</font>iteration <font color='#5555FF'>%</font> <font color='#979000'>1000</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>dcgan_sync</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> generator <font color='#5555FF'><</font><font color='#5555FF'><</font> discriminator <font color='#5555FF'><</font><font color='#5555FF'><</font> iteration; |
|
std::cout <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>step#: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> iteration <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\tdiscriminator loss: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> d_loss.<font color='#BB00BB'>mean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>*</font> <font color='#979000'>2</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> |
|
"<font color='#CC0000'>\tgenerator loss: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> g_loss.<font color='#BB00BB'>mean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> '<font color='#FF0000'>\n</font>'; |
|
win.<font color='#BB00BB'>set_image</font><font face='Lucida Console'>(</font><font color='#BB00BB'>tile_images</font><font face='Lucida Console'>(</font>fake_samples<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
win.<font color='#BB00BB'>set_title</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>DCGAN step#: </font>" <font color='#5555FF'>+</font> <font color='#BB00BB'>to_string</font><font face='Lucida Console'>(</font>iteration<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
d_loss.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
g_loss.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Once the training has finished, we don't need the discriminator any more. We just keep the |
|
</font> <font color='#009900'>// generator. |
|
</font> generator.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>dcgan_mnist.dnn</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> generator; |
|
|
|
<font color='#009900'>// To test the generator, we just forward some random noise through it and visualize the |
|
</font> <font color='#009900'>// output. |
|
</font> <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>win.<font color='#BB00BB'>is_closed</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> image <font color='#5555FF'>=</font> <font color='#BB00BB'>generate_image</font><font face='Lucida Console'>(</font>generator, <font color='#BB00BB'>make_noise</font><font face='Lucida Console'>(</font>rnd<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> real <font color='#5555FF'>=</font> <font color='#BB00BB'>discriminator</font><font face='Lucida Console'>(</font>image<font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0</font>; |
|
win.<font color='#BB00BB'>set_image</font><font face='Lucida Console'>(</font>image<font face='Lucida Console'>)</font>; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>The discriminator thinks it's </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font face='Lucida Console'>(</font>real ? "<font color='#CC0000'>real</font>" : "<font color='#CC0000'>fake</font>"<font face='Lucida Console'>)</font>; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>. Hit enter to generate a new image</font>"; |
|
cin.<font color='#BB00BB'>get</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>return</font> EXIT_SUCCESS; |
|
<b>}</b> |
|
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>exception<font color='#5555FF'>&</font> e<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#0000FF'>return</font> EXIT_FAILURE; |
|
<b>}</b> |
|
|
|
</pre></body></html> |