Aging_MouthReplace / dlibs /docs /dnn_dcgan_train_ex.cpp.html
AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
raw
history blame
36.3 kB
<html><!-- Created using the cpp_pretty_printer from the dlib C++ library. See http://dlib.net for updates. --><head><title>dlib C++ Library - dnn_dcgan_train_ex.cpp</title></head><body bgcolor='white'><pre>
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
</font><font color='#009900'>/*
This is an example illustrating the use of the deep learning tools from the dlib C++
Library. I'm assuming you have already read the <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a>, the
<a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a> and the <a href="dnn_introduction3_ex.cpp.html">dnn_introduction3_ex.cpp</a> examples. In this example
program we are going to show how one can train Generative Adversarial Networks (GANs). In
particular, we will train a Deep Convolutional Generative Adversarial Network (DCGAN) like
the one introduced in this paper:
"Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks"
by Alec Radford, Luke Metz, Soumith Chintala.
The main idea is that there are two neural networks training at the same time:
- the generator is in charge of generating images that look as close as possible to the
ones from the dataset.
- the discriminator will decide whether an image is fake (created by the generator) or real
(selected from the dataset).
Each training iteration alternates between training the discriminator and the generator.
We first train the discriminator with real and fake images and then use the gradient from
the discriminator to update the generator.
In this example, we are going to learn how to generate digits from the MNIST dataset, but
the same code can be run using the Fashion MNIST dataset:
https://github.com/zalandoresearch/fashion-mnist
*/</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>algorithm<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iostream<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>dnn.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>gui_widgets.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>matrix.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std;
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib;
<font color='#009900'>// Some helper definitions for the noise generation
</font><font color='#0000FF'>const</font> <font color='#0000FF'><u>size_t</u></font> noise_size <font color='#5555FF'>=</font> <font color='#979000'>100</font>;
<font color='#0000FF'>using</font> noise_t <font color='#5555FF'>=</font> std::array<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font>, <font color='#979000'>1</font>, <font color='#979000'>1</font><font color='#5555FF'>&gt;</font>, noise_size<font color='#5555FF'>&gt;</font>;
noise_t <b><a name='make_noise'></a>make_noise</b><font face='Lucida Console'>(</font>dlib::rand<font color='#5555FF'>&amp;</font> rnd<font face='Lucida Console'>)</font>
<b>{</b>
noise_t noise;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> n : noise<font face='Lucida Console'>)</font>
<b>{</b>
n <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_gaussian</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>return</font> noise;
<b>}</b>
<font color='#009900'>// A convolution with custom padding
</font><font color='#0000FF'>template</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>long</u></font> num_filters, <font color='#0000FF'><u>long</u></font> kernel_size, <font color='#0000FF'><u>int</u></font> stride, <font color='#0000FF'><u>int</u></font> padding, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> conp <font color='#5555FF'>=</font> add_layer<font color='#5555FF'>&lt;</font>con_<font color='#5555FF'>&lt;</font>num_filters, kernel_size, kernel_size, stride, stride, padding, padding<font color='#5555FF'>&gt;</font>, SUBNET<font color='#5555FF'>&gt;</font>;
<font color='#009900'>// A transposed convolution with custom padding
</font><font color='#0000FF'>template</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>long</u></font> num_filters, <font color='#0000FF'><u>long</u></font> kernel_size, <font color='#0000FF'><u>int</u></font> stride, <font color='#0000FF'><u>int</u></font> padding, <font color='#0000FF'>typename</font> SUBNET<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>using</font> contp <font color='#5555FF'>=</font> add_layer<font color='#5555FF'>&lt;</font>cont_<font color='#5555FF'>&lt;</font>num_filters, kernel_size, kernel_size, stride, stride, padding, padding<font color='#5555FF'>&gt;</font>, SUBNET<font color='#5555FF'>&gt;</font>;
<font color='#009900'>// The generator is made of a bunch of deconvolutional layers. Its input is a 1 x 1 x k noise
</font><font color='#009900'>// tensor, and the output is the generated image. The loss layer does not matter for the
</font><font color='#009900'>// training; we just stack a compatible one on top to be able to have a () operator on the
</font><font color='#009900'>// generator.
</font><font color='#0000FF'>using</font> generator_type <font color='#5555FF'>=</font>
loss_binary_log_per_pixel<font color='#5555FF'>&lt;</font>
sig<font color='#5555FF'>&lt;</font>contp<font color='#5555FF'>&lt;</font><font color='#979000'>1</font>, <font color='#979000'>4</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>,
relu<font color='#5555FF'>&lt;</font>bn_con<font color='#5555FF'>&lt;</font>contp<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>, <font color='#979000'>4</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>,
relu<font color='#5555FF'>&lt;</font>bn_con<font color='#5555FF'>&lt;</font>contp<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>, <font color='#979000'>3</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>,
relu<font color='#5555FF'>&lt;</font>bn_con<font color='#5555FF'>&lt;</font>contp<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>, <font color='#979000'>4</font>, <font color='#979000'>1</font>, <font color='#979000'>0</font>,
input<font color='#5555FF'>&lt;</font>noise_t<font color='#5555FF'>&gt;</font>
<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#009900'>// Now, let's proceed to define the discriminator, whose role will be to decide whether an
</font><font color='#009900'>// image is fake or not.
</font><font color='#0000FF'>using</font> discriminator_type <font color='#5555FF'>=</font>
loss_binary_log<font color='#5555FF'>&lt;</font>
conp<font color='#5555FF'>&lt;</font><font color='#979000'>1</font>, <font color='#979000'>3</font>, <font color='#979000'>1</font>, <font color='#979000'>0</font>,
leaky_relu<font color='#5555FF'>&lt;</font>bn_con<font color='#5555FF'>&lt;</font>conp<font color='#5555FF'>&lt;</font><font color='#979000'>256</font>, <font color='#979000'>4</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>,
leaky_relu<font color='#5555FF'>&lt;</font>bn_con<font color='#5555FF'>&lt;</font>conp<font color='#5555FF'>&lt;</font><font color='#979000'>128</font>, <font color='#979000'>4</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>,
leaky_relu<font color='#5555FF'>&lt;</font>conp<font color='#5555FF'>&lt;</font><font color='#979000'>64</font>, <font color='#979000'>4</font>, <font color='#979000'>2</font>, <font color='#979000'>1</font>,
input<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>
<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font>;
<font color='#009900'>// Some helper functions to generate and get the images from the generator
</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;</font> <b><a name='generate_image'></a>generate_image</b><font face='Lucida Console'>(</font>generator_type<font color='#5555FF'>&amp;</font> net, <font color='#0000FF'>const</font> noise_t<font color='#5555FF'>&amp;</font> noise<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>&gt;</font> output <font color='#5555FF'>=</font> <font color='#BB00BB'>net</font><font face='Lucida Console'>(</font>noise<font face='Lucida Console'>)</font>;
matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;</font> image;
<font color='#BB00BB'>assign_image</font><font face='Lucida Console'>(</font>image, <font color='#979000'>255</font> <font color='#5555FF'>*</font> output<font face='Lucida Console'>)</font>;
<font color='#0000FF'>return</font> image;
<b>}</b>
std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> <b><a name='get_generated_images'></a>get_generated_images</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> out<font face='Lucida Console'>)</font>
<b>{</b>
std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> images;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> n <font color='#5555FF'>=</font> <font color='#979000'>0</font>; n <font color='#5555FF'>&lt;</font> out.<font color='#BB00BB'>num_samples</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>n<font face='Lucida Console'>)</font>
<b>{</b>
matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>&gt;</font> output <font color='#5555FF'>=</font> <font color='#BB00BB'>image_plane</font><font face='Lucida Console'>(</font>out, n<font face='Lucida Console'>)</font>;
matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;</font> image;
<font color='#BB00BB'>assign_image</font><font face='Lucida Console'>(</font>image, <font color='#979000'>255</font> <font color='#5555FF'>*</font> output<font face='Lucida Console'>)</font>;
images.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>return</font> images;
<b>}</b>
<font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font><font color='#5555FF'>*</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font>
<b>{</b>
<font color='#009900'>// This example is going to run on the MNIST dataset.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>2</font><font face='Lucida Console'>)</font>
<b>{</b>
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>This example needs the MNIST dataset to run!</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>You can get MNIST from http://yann.lecun.com/exdb/mnist/</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Download the 4 files that comprise the dataset, decompress them, and</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>put them in a folder. Then give that folder as input to this program.</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<font color='#0000FF'>return</font> EXIT_FAILURE;
<b>}</b>
<font color='#009900'>// MNIST is broken into two parts, a training set of 60000 images and a test set of 10000
</font> <font color='#009900'>// images. Each image is labeled so that we know what hand written digit is depicted.
</font> <font color='#009900'>// These next statements load the dataset into memory.
</font> std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> training_images;
std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> training_labels;
std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> testing_images;
std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> testing_labels;
<font color='#BB00BB'>load_mnist_dataset</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>], training_images, training_labels, testing_images, testing_labels<font face='Lucida Console'>)</font>;
<font color='#009900'>// Fix the random generator seeds for network initialization and noise
</font> <font color='#BB00BB'>srand</font><font face='Lucida Console'>(</font><font color='#979000'>1234</font><font face='Lucida Console'>)</font>;
dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>rand</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// Instantiate both generator and discriminator
</font> generator_type generator;
discriminator_type discriminator;
<font color='#009900'>// Set up all leaky_relu_ layers in the discriminator to have alpha = 0.2
</font> <font color='#BB00BB'>visit_computational_layers</font><font face='Lucida Console'>(</font>discriminator, []<font face='Lucida Console'>(</font>leaky_relu_<font color='#5555FF'>&amp;</font> l<font face='Lucida Console'>)</font><b>{</b> l <font color='#5555FF'>=</font> <font color='#BB00BB'>leaky_relu_</font><font face='Lucida Console'>(</font><font color='#979000'>0.2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
<font color='#009900'>// Remove the bias learning from all bn_ inputs in both networks
</font> <font color='#BB00BB'>disable_duplicative_biases</font><font face='Lucida Console'>(</font>generator<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>disable_duplicative_biases</font><font face='Lucida Console'>(</font>discriminator<font face='Lucida Console'>)</font>;
<font color='#009900'>// Forward random noise so that we see the tensor size at each layer
</font> <font color='#BB00BB'>discriminator</font><font face='Lucida Console'>(</font><font color='#BB00BB'>generate_image</font><font face='Lucida Console'>(</font>generator, <font color='#BB00BB'>make_noise</font><font face='Lucida Console'>(</font>rnd<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>generator (</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>count_parameters</font><font face='Lucida Console'>(</font>generator<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> parameters)</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> generator <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>discriminator (</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>count_parameters</font><font face='Lucida Console'>(</font>discriminator<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> parameters)</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> discriminator <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<font color='#009900'>// The solvers for the generator and discriminator networks. In this example, we are going to
</font> <font color='#009900'>// train the networks manually, so we don't need to use a dnn_trainer. Note that the
</font> <font color='#009900'>// discriminator could be trained using a dnn_trainer, but not the generator, since its
</font> <font color='#009900'>// training process is a bit particular.
</font> std::vector<font color='#5555FF'>&lt;</font>adam<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>g_solvers</font><font face='Lucida Console'>(</font>generator.num_computational_layers, <font color='#BB00BB'>adam</font><font face='Lucida Console'>(</font><font color='#979000'>0</font>, <font color='#979000'>0.5</font>, <font color='#979000'>0.999</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
std::vector<font color='#5555FF'>&lt;</font>adam<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>d_solvers</font><font face='Lucida Console'>(</font>discriminator.num_computational_layers, <font color='#BB00BB'>adam</font><font face='Lucida Console'>(</font><font color='#979000'>0</font>, <font color='#979000'>0.5</font>, <font color='#979000'>0.999</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'><u>double</u></font> learning_rate <font color='#5555FF'>=</font> <font color='#979000'>2e</font><font color='#5555FF'>-</font><font color='#979000'>4</font>;
<font color='#009900'>// Resume training from last sync file
</font> <font color='#0000FF'><u>size_t</u></font> iteration <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>file_exists</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>dcgan_sync</font>"<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>dcgan_sync</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> generator <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> discriminator <font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> iteration;
<b>}</b>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>size_t</u></font> minibatch_size <font color='#5555FF'>=</font> <font color='#979000'>64</font>;
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>real_labels</font><font face='Lucida Console'>(</font>minibatch_size, <font color='#979000'>1</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>fake_labels</font><font face='Lucida Console'>(</font>minibatch_size, <font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
dlib::image_window win;
resizable_tensor real_samples_tensor, fake_samples_tensor, noises_tensor;
running_stats<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> g_loss, d_loss;
<font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>iteration <font color='#5555FF'>&lt;</font> <font color='#979000'>50000</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// Train the discriminator with real images
</font> std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>char</u></font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> real_samples;
<font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>real_samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> minibatch_size<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>auto</font> idx <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>%</font> training_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
real_samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>training_images[idx]<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// The following lines are equivalent to calling train_one_step(real_samples, real_labels)
</font> discriminator.<font color='#BB00BB'>to_tensor</font><font face='Lucida Console'>(</font>real_samples.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, real_samples.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, real_samples_tensor<font face='Lucida Console'>)</font>;
d_loss.<font color='#BB00BB'>add</font><font face='Lucida Console'>(</font>discriminator.<font color='#BB00BB'>compute_loss</font><font face='Lucida Console'>(</font>real_samples_tensor, real_labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
discriminator.<font color='#BB00BB'>back_propagate_error</font><font face='Lucida Console'>(</font>real_samples_tensor<font face='Lucida Console'>)</font>;
discriminator.<font color='#BB00BB'>update_parameters</font><font face='Lucida Console'>(</font>d_solvers, learning_rate<font face='Lucida Console'>)</font>;
<font color='#009900'>// Train the discriminator with fake images
</font> <font color='#009900'>// 1. Generate some random noise
</font> std::vector<font color='#5555FF'>&lt;</font>noise_t<font color='#5555FF'>&gt;</font> noises;
<font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>noises.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> minibatch_size<font face='Lucida Console'>)</font>
<b>{</b>
noises.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#BB00BB'>make_noise</font><font face='Lucida Console'>(</font>rnd<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// 2. Convert noises into a tensor
</font> generator.<font color='#BB00BB'>to_tensor</font><font face='Lucida Console'>(</font>noises.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, noises.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, noises_tensor<font face='Lucida Console'>)</font>;
<font color='#009900'>// 3. Forward the noise through the network and convert the outputs into images.
</font> <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> fake_samples <font color='#5555FF'>=</font> <font color='#BB00BB'>get_generated_images</font><font face='Lucida Console'>(</font>generator.<font color='#BB00BB'>forward</font><font face='Lucida Console'>(</font>noises_tensor<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// 4. Finally train the discriminator. The following lines are equivalent to calling
</font> <font color='#009900'>// train_one_step(fake_samples, fake_labels)
</font> discriminator.<font color='#BB00BB'>to_tensor</font><font face='Lucida Console'>(</font>fake_samples.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, fake_samples.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, fake_samples_tensor<font face='Lucida Console'>)</font>;
d_loss.<font color='#BB00BB'>add</font><font face='Lucida Console'>(</font>discriminator.<font color='#BB00BB'>compute_loss</font><font face='Lucida Console'>(</font>fake_samples_tensor, fake_labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
discriminator.<font color='#BB00BB'>back_propagate_error</font><font face='Lucida Console'>(</font>fake_samples_tensor<font face='Lucida Console'>)</font>;
discriminator.<font color='#BB00BB'>update_parameters</font><font face='Lucida Console'>(</font>d_solvers, learning_rate<font face='Lucida Console'>)</font>;
<font color='#009900'>// Train the generator
</font> <font color='#009900'>// This part is the essence of the Generative Adversarial Networks. Until now, we have
</font> <font color='#009900'>// just trained a binary classifier that the generator is not aware of. But now, the
</font> <font color='#009900'>// discriminator is going to give feedback to the generator on how it should update
</font> <font color='#009900'>// itself to generate more realistic images. The following lines perform the same
</font> <font color='#009900'>// actions as train_one_step() except for the network update part. They can also be
</font> <font color='#009900'>// seen as test_one_step() plus the error back propagation.
</font>
<font color='#009900'>// Forward the fake samples and compute the loss with real labels
</font> g_loss.<font color='#BB00BB'>add</font><font face='Lucida Console'>(</font>discriminator.<font color='#BB00BB'>compute_loss</font><font face='Lucida Console'>(</font>fake_samples_tensor, real_labels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// Back propagate the error to fill the final data gradient
</font> discriminator.<font color='#BB00BB'>back_propagate_error</font><font face='Lucida Console'>(</font>fake_samples_tensor<font face='Lucida Console'>)</font>;
<font color='#009900'>// Get the gradient that will tell the generator how to update itself
</font> <font color='#0000FF'>const</font> tensor<font color='#5555FF'>&amp;</font> d_grad <font color='#5555FF'>=</font> discriminator.<font color='#BB00BB'>get_final_data_gradient</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
generator.<font color='#BB00BB'>back_propagate_error</font><font face='Lucida Console'>(</font>noises_tensor, d_grad<font face='Lucida Console'>)</font>;
generator.<font color='#BB00BB'>update_parameters</font><font face='Lucida Console'>(</font>g_solvers, learning_rate<font face='Lucida Console'>)</font>;
<font color='#009900'>// At some point, we should see that the generated images start looking like samples from
</font> <font color='#009900'>// the MNIST dataset
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>+</font><font color='#5555FF'>+</font>iteration <font color='#5555FF'>%</font> <font color='#979000'>1000</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>dcgan_sync</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> generator <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> discriminator <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> iteration;
std::cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>step#: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> iteration <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\tdiscriminator loss: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> d_loss.<font color='#BB00BB'>mean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>*</font> <font color='#979000'>2</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>
"<font color='#CC0000'>\tgenerator loss: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> g_loss.<font color='#BB00BB'>mean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> '<font color='#FF0000'>\n</font>';
win.<font color='#BB00BB'>set_image</font><font face='Lucida Console'>(</font><font color='#BB00BB'>tile_images</font><font face='Lucida Console'>(</font>fake_samples<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
win.<font color='#BB00BB'>set_title</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>DCGAN step#: </font>" <font color='#5555FF'>+</font> <font color='#BB00BB'>to_string</font><font face='Lucida Console'>(</font>iteration<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
d_loss.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
g_loss.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<font color='#009900'>// Once the training has finished, we don't need the discriminator any more. We just keep the
</font> <font color='#009900'>// generator.
</font> generator.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>dcgan_mnist.dnn</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> generator;
<font color='#009900'>// To test the generator, we just forward some random noise through it and visualize the
</font> <font color='#009900'>// output.
</font> <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>win.<font color='#BB00BB'>is_closed</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> image <font color='#5555FF'>=</font> <font color='#BB00BB'>generate_image</font><font face='Lucida Console'>(</font>generator, <font color='#BB00BB'>make_noise</font><font face='Lucida Console'>(</font>rnd<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> real <font color='#5555FF'>=</font> <font color='#BB00BB'>discriminator</font><font face='Lucida Console'>(</font>image<font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0</font>;
win.<font color='#BB00BB'>set_image</font><font face='Lucida Console'>(</font>image<font face='Lucida Console'>)</font>;
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>The discriminator thinks it's </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font face='Lucida Console'>(</font>real ? "<font color='#CC0000'>real</font>" : "<font color='#CC0000'>fake</font>"<font face='Lucida Console'>)</font>;
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>. Hit enter to generate a new image</font>";
cin.<font color='#BB00BB'>get</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'>return</font> EXIT_SUCCESS;
<b>}</b>
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>exception<font color='#5555FF'>&amp;</font> e<font face='Lucida Console'>)</font>
<b>{</b>
cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<font color='#0000FF'>return</font> EXIT_FAILURE;
<b>}</b>
</pre></body></html>