<html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - dnn_semantic_segmentation_train_ex.cpp</title></head><body bgcolor='white'><pre>
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
</font><font color='#009900'>/*
    This example shows how to train a semantic segmentation net using the PASCAL VOC2012
    dataset.  For an introduction to what segmentation is, see the accompanying header file
    dnn_semantic_segmentation_ex.h.

    Instructions for running the example:
    1. Download the PASCAL VOC2012 data, and untar it somewhere.
       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    2. Build the dnn_semantic_segmentation_train_ex example program.
    3. Run:
       ./dnn_semantic_segmentation_train_ex /path/to/VOC2012
    4. Wait while the network is being trained.
    5. Build the dnn_semantic_segmentation_ex example program.
    6. Run:
       ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images

    It would be a good idea to become familiar with dlib's DNN tooling before reading this
    example.  So you should read <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a> and <a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a>
    before reading this example program.
*/</font>

<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='dnn_semantic_segmentation_ex.h.html'>dnn_semantic_segmentation_ex.h</a>"

<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iostream<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>image_transforms.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>dir_nav.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iterator<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>thread<font color='#5555FF'>&gt;</font>

<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std;
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib;

<font color='#009900'>// A single training sample: one image crop together with its per-pixel labels.
</font><font color='#009900'>// A mini-batch comprises many of these.
</font><font color='#0000FF'>struct</font> <b><a name='training_sample'></a>training_sample</b>
<b>{</b>
    matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> input_image; <font color='#009900'>// The cropped RGB image that is given to the trainer.
</font>    matrix<font color='#5555FF'>&lt;</font>uint16_t<font color='#5555FF'>&gt;</font> label_image; <font color='#009900'>// The ground-truth label of each pixel.
</font><b>}</b>;

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Returns a randomly sized and randomly placed square rectangle to crop from img.  The
</font><font color='#009900'>// side length is a random fraction (between mins and maxs) of the shorter image side.
</font><font color='#009900'>// Three numbers are drawn from rnd: the scale, then the x offset, then the y offset.
</font>rectangle <b><a name='make_random_cropping_rect'></a>make_random_cropping_rect</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> img,
    dlib::rand<font color='#5555FF'>&amp;</font> rnd
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// figure out what rectangle we want to crop from the image
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> mins <font color='#5555FF'>=</font> <font color='#979000'>0.466666666</font>, maxs <font color='#5555FF'>=</font> <font color='#979000'>0.875</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> scale <font color='#5555FF'>=</font> mins <font color='#5555FF'>+</font> rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font><font face='Lucida Console'>(</font>maxs<font color='#5555FF'>-</font>mins<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> size <font color='#5555FF'>=</font> scale<font color='#5555FF'>*</font>std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> rectangle <font color='#BB00BB'>rect</font><font face='Lucida Console'>(</font>size, size<font face='Lucida Console'>)</font>;

    <font color='#009900'>// Randomly shift the box around.  Each coordinate is drawn in its own statement so
</font>    <font color='#009900'>// that the order of the two draws does not depend on the compiler's argument
</font>    <font color='#009900'>// evaluation order, and a zero range (e.g. an empty image) is never used as a modulus.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> x_range <font color='#5555FF'>=</font> img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> y_range <font color='#5555FF'>=</font> img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> x <font color='#5555FF'>=</font> x_range <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> ? rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font>x_range : <font color='#979000'>0</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> y <font color='#5555FF'>=</font> y_range <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> ? rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font>y_range : <font color='#979000'>0</font>;
    <font color='#0000FF'>return</font> <font color='#BB00BB'>move_rect</font><font face='Lucida Console'>(</font>rect, <font color='#BB00BB'>point</font><font face='Lucida Console'>(</font>x, y<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<b>}</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Fills crop with a randomly chosen region of input_image and the matching region of
</font><font color='#009900'>// label_image, both resampled to 227x227.  The pair is randomly mirrored and the image
</font><font color='#009900'>// (not the labels) gets a random color offset.
</font><font color='#0000FF'><u>void</u></font> <b><a name='randomly_crop_image'></a>randomly_crop_image</b> <font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> input_image,
    <font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>uint16_t<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> label_image,
    training_sample<font color='#5555FF'>&amp;</font> crop,
    dlib::rand<font color='#5555FF'>&amp;</font> rnd
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// The image and its labels are cut from one and the same randomly chosen region.
</font>    <font color='#0000FF'>const</font> chip_details <font color='#BB00BB'>details</font><font face='Lucida Console'>(</font><font color='#BB00BB'>make_random_cropping_rect</font><font face='Lucida Console'>(</font>input_image, rnd<font face='Lucida Console'>)</font>, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>227</font>, <font color='#979000'>227</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// Bilinear interpolation is used for the color image...
</font>    <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>input_image, details, crop.input_image, <font color='#BB00BB'>interpolate_bilinear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// ...but not for the labels: blending two class labels is meaningless (a bicycle is
</font>    <font color='#009900'>// not half-way between an aeroplane and a bird), so nearest-neighbor sampling is
</font>    <font color='#009900'>// used for them instead.
</font>    <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>label_image, details, crop.label_image, <font color='#BB00BB'>interpolate_nearest_neighbor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// Randomly mirror the image and its labels together.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> mirror <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font> <font color='#979000'>0.5</font>;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>mirror<font face='Lucida Console'>)</font>
    <b>{</b>
        crop.input_image <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>crop.input_image<font face='Lucida Console'>)</font>;
        crop.label_image <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>crop.label_image<font face='Lucida Console'>)</font>;
    <b>}</b>

    <font color='#009900'>// Finally, randomly adjust the colors of the image.
</font>    <font color='#BB00BB'>apply_random_color_offset</font><font face='Lucida Console'>(</font>crop.input_image, rnd<font face='Lucida Console'>)</font>;
<b>}</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Calculate the per-pixel accuracy on a dataset whose file names are supplied as a parameter.
</font><font color='#009900'>// Pixels whose ground-truth label is label_to_ignore are excluded.  Returns the fraction of
</font><font color='#009900'>// the remaining pixels for which the net output equals the ground truth; if no pixel was
</font><font color='#009900'>// counted at all the result is 0/0, i.e. not a number.
</font><font color='#0000FF'><u>double</u></font> <b><a name='calculate_accuracy'></a>calculate_accuracy</b><font face='Lucida Console'>(</font>anet_type<font color='#5555FF'>&amp;</font> anet, <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>image_info<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> dataset<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// The counts are accumulated over every pixel of every image, so use a 64-bit type:
</font>    <font color='#009900'>// a 32-bit int would overflow after roughly 2.1 billion pixels.
</font>    <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <font color='#0000FF'><u>long</u></font> num_right <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <font color='#0000FF'><u>long</u></font> num_wrong <font color='#5555FF'>=</font> <font color='#979000'>0</font>;

    <font color='#009900'>// These buffers are declared outside the loop so they are reused for every image.
</font>    matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> input_image;
    matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> rgb_label_image;
    matrix<font color='#5555FF'>&lt;</font>uint16_t<font color='#5555FF'>&gt;</font> index_label_image;
    matrix<font color='#5555FF'>&lt;</font>uint16_t<font color='#5555FF'>&gt;</font> net_output;

    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> info : dataset<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// Load the input image.
</font>        <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>input_image, info.image_filename<font face='Lucida Console'>)</font>;

        <font color='#009900'>// Load the ground-truth (RGB) labels.
</font>        <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>rgb_label_image, info.class_label_filename<font face='Lucida Console'>)</font>;

        <font color='#009900'>// Create predictions for each pixel. At this point, the type of each prediction
</font>        <font color='#009900'>// is an index (a value between 0 and 20). Note that the net may return an image
</font>        <font color='#009900'>// that is not exactly the same size as the input.
</font>        <font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>uint16_t<font color='#5555FF'>&gt;</font> temp <font color='#5555FF'>=</font> <font color='#BB00BB'>anet</font><font face='Lucida Console'>(</font>input_image<font face='Lucida Console'>)</font>;

        <font color='#009900'>// Convert the RGB values to indexes.
</font>        <font color='#BB00BB'>rgb_label_image_to_index_label_image</font><font face='Lucida Console'>(</font>rgb_label_image, index_label_image<font face='Lucida Console'>)</font>;

        <font color='#009900'>// Crop the net output, around its center, to be exactly the same size as the input.
</font>        <font color='#0000FF'>const</font> chip_details <font color='#BB00BB'>crop_details</font><font face='Lucida Console'>(</font>
            <font color='#BB00BB'>centered_rect</font><font face='Lucida Console'>(</font>temp.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>/</font> <font color='#979000'>2</font>, temp.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>/</font> <font color='#979000'>2</font>, input_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, input_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>,
            <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font>input_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, input_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <font face='Lucida Console'>)</font>;
        <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>temp, crop_details, net_output, <font color='#BB00BB'>interpolate_nearest_neighbor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

        <font color='#009900'>// NOTE(review): the loops below index net_output with the label image's dimensions,
</font>        <font color='#009900'>// so this assumes the label image is the same size as the input image - confirm.
</font>        <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> nr <font color='#5555FF'>=</font> index_label_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> nc <font color='#5555FF'>=</font> index_label_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

        <font color='#009900'>// Compare the predicted values to the ground-truth values.
</font>        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> r <font color='#5555FF'>=</font> <font color='#979000'>0</font>; r <font color='#5555FF'>&lt;</font> nr; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>r<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> c <font color='#5555FF'>=</font> <font color='#979000'>0</font>; c <font color='#5555FF'>&lt;</font> nc; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>c<font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#0000FF'>const</font> uint16_t truth <font color='#5555FF'>=</font> <font color='#BB00BB'>index_label_image</font><font face='Lucida Console'>(</font>r, c<font face='Lucida Console'>)</font>;

                <font color='#009900'>// Pixels marked label_to_ignore count neither as right nor as wrong.
</font>                <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>truth <font color='#5555FF'>=</font><font color='#5555FF'>=</font> dlib::loss_multiclass_log_per_pixel_::label_to_ignore<font face='Lucida Console'>)</font>
                    <font color='#0000FF'>continue</font>;

                <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>net_output</font><font face='Lucida Console'>(</font>r, c<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> truth<font face='Lucida Console'>)</font>
                    <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right;
                <font color='#0000FF'>else</font>
                    <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong;
            <b>}</b>
        <b>}</b>
    <b>}</b>

    <font color='#009900'>// Return the accuracy estimate.
</font>    <font color='#0000FF'>return</font> num_right <font color='#5555FF'>/</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>num_right <font color='#5555FF'>+</font> num_wrong<font face='Lucida Console'>)</font>;
<b>}</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Trains the segmentation net on the PASCAL VOC2012 data found at argv[1], saves it, and
</font><font color='#009900'>// then prints its per-pixel accuracy on the training and validation listings.  An optional
</font><font color='#009900'>// second argument sets the mini-batch size.  Returns 1 on a usage error, when no images
</font><font color='#009900'>// are found, or when an exception is caught.
</font><font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font><font color='#5555FF'>*</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font>
<b>{</b>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'>&lt;</font> <font color='#979000'>2</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> argc <font color='#5555FF'>&gt;</font> <font color='#979000'>3</font><font face='Lucida Console'>)</font>
    <b>{</b>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>To run this program you need a copy of the PASCAL VOC2012 dataset.</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>You call this program like this: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>./dnn_semantic_segmentation_train_ex /path/to/VOC2012 [minibatch-size]</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
    <b>}</b>

    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\nSCANNING PASCAL VOC2012 DATASET\n</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> listing <font color='#5555FF'>=</font> <font color='#BB00BB'>get_pascal_voc2012_train_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font>;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>images in dataset: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
    <b>{</b>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Didn't find the VOC2012 dataset. </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
    <b>}</b>

    <font color='#009900'>// A mini-batch smaller than the default can be used with GPUs having less memory.
</font>    <font color='#009900'>// The value is parsed into a signed int first so that zero or a negative number is
</font>    <font color='#009900'>// rejected instead of silently becoming an empty or an enormous mini-batch.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> requested_minibatch_size <font color='#5555FF'>=</font> argc <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>3</font> ? std::<font color='#BB00BB'>stoi</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>2</font>]<font face='Lucida Console'>)</font> : <font color='#979000'>23</font>;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>requested_minibatch_size <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
    <b>{</b>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>The mini-batch size must be a positive integer.</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
    <b>}</b>
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font> minibatch_size <font color='#5555FF'>=</font> requested_minibatch_size;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>mini-batch size: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> minibatch_size <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> initial_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> weight_decay <font color='#5555FF'>=</font> <font color='#979000'>0.0001</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>;

    bnet_type bnet;
    dnn_trainer<font color='#5555FF'>&lt;</font>bnet_type<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>trainer</font><font face='Lucida Console'>(</font>bnet,<font color='#BB00BB'>sgd</font><font face='Lucida Console'>(</font>weight_decay, momentum<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font>initial_learning_rate<font face='Lucida Console'>)</font>;
    trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>pascal_voc2012_trainer_state_file.dat</font>", std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// This threshold is probably excessively large.
</font>    trainer.<font color='#BB00BB'>set_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font color='#979000'>5000</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// Since the progress threshold is so large might as well set the batch normalization
</font>    <font color='#009900'>// stats window to something big too.
</font>    <font color='#BB00BB'>set_all_bn_running_stats_window_sizes</font><font face='Lucida Console'>(</font>bnet, <font color='#979000'>1000</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// Output training parameters.
</font>    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> trainer <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> samples;
    std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font>uint16_t<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> labels;

    <font color='#009900'>// Start a bunch of threads that read images from disk and pull out random crops.  It's
</font>    <font color='#009900'>// important to be sure to feed the GPU fast enough to keep it busy.  Using multiple
</font>    <font color='#009900'>// threads for this kind of data preparation helps us do that.  Each thread puts the
</font>    <font color='#009900'>// crops into the data queue.
</font>    dlib::pipe<font color='#5555FF'>&lt;</font>training_sample<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>auto</font> f <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>data, <font color='#5555FF'>&amp;</font>listing]<font face='Lucida Console'>(</font>time_t seed<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// The per-thread seed is added to the current time so that the loaders, which
</font>        <font color='#009900'>// start at about the same moment, do not all use the same random sequence.
</font>        dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>seed<font face='Lucida Console'>)</font>;
        matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> input_image;
        matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> rgb_label_image;
        matrix<font color='#5555FF'>&lt;</font>uint16_t<font color='#5555FF'>&gt;</font> index_label_image;
        training_sample temp;
        <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>is_enabled</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#009900'>// Pick a random input image.
</font>            <font color='#0000FF'>const</font> image_info<font color='#5555FF'>&amp;</font> image_info <font color='#5555FF'>=</font> listing[rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>];

            <font color='#009900'>// Load the input image.
</font>            <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>input_image, image_info.image_filename<font face='Lucida Console'>)</font>;

            <font color='#009900'>// Load the ground-truth (RGB) labels.
</font>            <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>rgb_label_image, image_info.class_label_filename<font face='Lucida Console'>)</font>;

            <font color='#009900'>// Convert the RGB values to indexes.
</font>            <font color='#BB00BB'>rgb_label_image_to_index_label_image</font><font face='Lucida Console'>(</font>rgb_label_image, index_label_image<font face='Lucida Console'>)</font>;

            <font color='#009900'>// Randomly pick a part of the image.
</font>            <font color='#BB00BB'>randomly_crop_image</font><font face='Lucida Console'>(</font>input_image, index_label_image, temp, rnd<font face='Lucida Console'>)</font>;

            <font color='#009900'>// Push the result to be used by the trainer.
</font>            data.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
        <b>}</b>
    <b>}</b>;

    <font color='#009900'>// Disabling the queue ends the loader loops; after that every loader thread that was
</font>    <font color='#009900'>// started is joined.  This must also happen when training throws, because destroying
</font>    <font color='#009900'>// a std::thread that is still joinable terminates the program before the error
</font>    <font color='#009900'>// message at the bottom of main can be printed.
</font>    std::vector<font color='#5555FF'>&lt;</font>std::thread<font color='#5555FF'>&gt;</font> data_loaders;
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> stop_data_loaders <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>data, <font color='#5555FF'>&amp;</font>data_loaders]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
    <b>{</b>
        data.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> loader : data_loaders<font face='Lucida Console'>)</font>
            loader.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <b>}</b>;

    <font color='#0000FF'>try</font>
    <b>{</b>
        <font color='#009900'>// Four loader threads, seeded 1 to 4.
</font>        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> seed <font color='#5555FF'>=</font> <font color='#979000'>1</font>; seed <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>4</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>seed<font face='Lucida Console'>)</font>
            data_loaders.<font color='#BB00BB'>emplace_back</font><font face='Lucida Console'>(</font>[f, seed]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font>seed<font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;

        <font color='#009900'>// The main training loop.  Keep making mini-batches and giving them to the trainer.
</font>        <font color='#009900'>// We will run until the learning rate has dropped below 1e-4.
</font>        <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>
        <b>{</b>
            samples.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            labels.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

            <font color='#009900'>// make a mini-batch
</font>            training_sample temp;
            <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> minibatch_size<font face='Lucida Console'>)</font>
            <b>{</b>
                data.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;

                samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.input_image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.label_image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
            <b>}</b>

            trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;
        <b>}</b>
    <b>}</b>
    <font color='#0000FF'>catch</font> <font face='Lucida Console'>(</font>...<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// Stop the loaders, then let the handler at the bottom of main report the error.
</font>        <font color='#BB00BB'>stop_data_loaders</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'>throw</font>;
    <b>}</b>

    <font color='#009900'>// Training done, tell threads to stop and make sure to wait for them to finish before
</font>    <font color='#009900'>// moving on.
</font>    <font color='#BB00BB'>stop_data_loaders</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// also wait for threaded processing to stop in the trainer.
</font>    trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

    bnet.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>saving network</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>semantic_segmentation_net_filename<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> bnet;


    <font color='#009900'>// Make a copy of the network to use it for inference.
</font>    anet_type anet <font color='#5555FF'>=</font> bnet;

    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Testing the network...</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    <font color='#009900'>// Find the accuracy of the newly trained network on both the training and the validation sets.
</font>    <font color='#009900'>// The training listing scanned at startup is reused rather than scanned again.
</font>    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>train accuracy  :  </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>calculate_accuracy</font><font face='Lucida Console'>(</font>anet, listing<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>val accuracy    :  </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>calculate_accuracy</font><font face='Lucida Console'>(</font>anet, <font color='#BB00BB'>get_pascal_voc2012_val_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#0000FF'>return</font> <font color='#979000'>0</font>;
<b>}</b>
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> std::exception<font color='#5555FF'>&amp;</font> e<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Report the failure and signal it through the exit code as well.
</font>    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
<b>}</b>


</pre></body></html>