<html><!-- Created using the cpp_pretty_printer from the dlib C++ library. See http://dlib.net for updates. --><head><title>dlib C++ Library - dnn_semantic_segmentation_train_ex.cpp</title></head><body bgcolor='white'><pre> <font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt </font><font color='#009900'>/* This example shows how to train a semantic segmentation net using the PASCAL VOC2012 dataset. For an introduction to what segmentation is, see the accompanying header file dnn_semantic_segmentation_ex.h. Instructions on how to run the example: 1. Download the PASCAL VOC2012 data, and untar it somewhere. http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 2. Build the dnn_semantic_segmentation_train_ex example program. 3. Run: ./dnn_semantic_segmentation_train_ex /path/to/VOC2012 4. Wait while the network is being trained. 5. Build the dnn_semantic_segmentation_ex example program. 6. Run: ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images It would be a good idea to become familiar with dlib's DNN tooling before reading this example. So you should read <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a> and <a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a> before reading this example program. 
*/</font> <font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='dnn_semantic_segmentation_ex.h.html'>dnn_semantic_segmentation_ex.h</a>" <font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font> <font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>></font> <font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>image_transforms.h<font color='#5555FF'>></font> <font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>dir_nav.h<font color='#5555FF'>></font> <font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iterator<font color='#5555FF'>></font> <font color='#0000FF'>#include</font> <font color='#5555FF'><</font>thread<font color='#5555FF'>></font> <font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std; <font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib; <font color='#009900'>// A single training sample. A mini-batch comprises many of these. </font><font color='#0000FF'>struct</font> <b><a name='training_sample'></a>training_sample</b> <b>{</b> matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> input_image; matrix<font color='#5555FF'><</font>uint16_t<font color='#5555FF'>></font> label_image; <font color='#009900'>// The ground-truth label of each pixel. 
</font><b>}</b>; <font color='#009900'>// ---------------------------------------------------------------------------------------- </font> rectangle <b><a name='make_random_cropping_rect'></a>make_random_cropping_rect</b><font face='Lucida Console'>(</font> <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> img, dlib::rand<font color='#5555FF'>&</font> rnd <font face='Lucida Console'>)</font> <b>{</b> <font color='#009900'>// figure out what rectangle we want to crop from the image </font> <font color='#0000FF'><u>double</u></font> mins <font color='#5555FF'>=</font> <font color='#979000'>0.466666666</font>, maxs <font color='#5555FF'>=</font> <font color='#979000'>0.875</font>; <font color='#0000FF'>auto</font> scale <font color='#5555FF'>=</font> mins <font color='#5555FF'>+</font> rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>*</font><font face='Lucida Console'>(</font>maxs<font color='#5555FF'>-</font>mins<font face='Lucida Console'>)</font>; <font color='#0000FF'>auto</font> size <font color='#5555FF'>=</font> scale<font color='#5555FF'>*</font>std::<font color='#BB00BB'>min</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; rectangle <font color='#BB00BB'>rect</font><font face='Lucida Console'>(</font>size, size<font face='Lucida Console'>)</font>; <font color='#009900'>// randomly shift the box around </font> point <font color='#BB00BB'>offset</font><font face='Lucida Console'>(</font>rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font 
color='#5555FF'>%</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>, rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font><font face='Lucida Console'>(</font>img.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font>rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <font color='#0000FF'>return</font> <font color='#BB00BB'>move_rect</font><font face='Lucida Console'>(</font>rect, offset<font face='Lucida Console'>)</font>; <b>}</b> <font color='#009900'>// ---------------------------------------------------------------------------------------- </font> <font color='#0000FF'><u>void</u></font> <b><a name='randomly_crop_image'></a>randomly_crop_image</b> <font face='Lucida Console'>(</font> <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> input_image, <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>uint16_t<font color='#5555FF'>></font><font color='#5555FF'>&</font> label_image, training_sample<font color='#5555FF'>&</font> crop, dlib::rand<font color='#5555FF'>&</font> rnd <font face='Lucida Console'>)</font> <b>{</b> <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> rect <font color='#5555FF'>=</font> <font color='#BB00BB'>make_random_cropping_rect</font><font face='Lucida Console'>(</font>input_image, rnd<font face='Lucida Console'>)</font>; <font 
color='#0000FF'>const</font> chip_details <font color='#BB00BB'>chip_details</font><font face='Lucida Console'>(</font>rect, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>227</font>, <font color='#979000'>227</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <font color='#009900'>// Crop the input image. </font> <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>input_image, chip_details, crop.input_image, <font color='#BB00BB'>interpolate_bilinear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <font color='#009900'>// Crop the labels correspondingly. However, note that here bilinear </font> <font color='#009900'>// interpolation would make absolutely no sense - you wouldn't say that </font> <font color='#009900'>// a bicycle is half-way between an aeroplane and a bird, would you? </font> <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>label_image, chip_details, crop.label_image, <font color='#BB00BB'>interpolate_nearest_neighbor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <font color='#009900'>// Also randomly flip the input image and the labels. 
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>rnd.<font color='#BB00BB'>get_random_double</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font> <font color='#979000'>0.5</font><font face='Lucida Console'>)</font> <b>{</b> crop.input_image <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>crop.input_image<font face='Lucida Console'>)</font>; crop.label_image <font color='#5555FF'>=</font> <font color='#BB00BB'>fliplr</font><font face='Lucida Console'>(</font>crop.label_image<font face='Lucida Console'>)</font>; <b>}</b> <font color='#009900'>// And then randomly adjust the colors. </font> <font color='#BB00BB'>apply_random_color_offset</font><font face='Lucida Console'>(</font>crop.input_image, rnd<font face='Lucida Console'>)</font>; <b>}</b> <font color='#009900'>// ---------------------------------------------------------------------------------------- </font> <font color='#009900'>// Calculate the per-pixel accuracy on a dataset whose file names are supplied as a parameter. 
</font><font color='#0000FF'><u>double</u></font> <b><a name='calculate_accuracy'></a>calculate_accuracy</b><font face='Lucida Console'>(</font>anet_type<font color='#5555FF'>&</font> anet, <font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>image_info<font color='#5555FF'>></font><font color='#5555FF'>&</font> dataset<font face='Lucida Console'>)</font> <b>{</b> <font color='#0000FF'><u>int</u></font> num_right <font color='#5555FF'>=</font> <font color='#979000'>0</font>; <font color='#0000FF'><u>int</u></font> num_wrong <font color='#5555FF'>=</font> <font color='#979000'>0</font>; matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> input_image; matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> rgb_label_image; matrix<font color='#5555FF'><</font>uint16_t<font color='#5555FF'>></font> index_label_image; matrix<font color='#5555FF'><</font>uint16_t<font color='#5555FF'>></font> net_output; <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> image_info : dataset<font face='Lucida Console'>)</font> <b>{</b> <font color='#009900'>// Load the input image. </font> <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>input_image, image_info.image_filename<font face='Lucida Console'>)</font>; <font color='#009900'>// Load the ground-truth (RGB) labels. </font> <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>rgb_label_image, image_info.class_label_filename<font face='Lucida Console'>)</font>; <font color='#009900'>// Create predictions for each pixel. At this point, the type of each prediction </font> <font color='#009900'>// is an index (a value between 0 and 20). Note that the net may return an image </font> <font color='#009900'>// that is not exactly the same size as the input. 
</font> <font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>uint16_t<font color='#5555FF'>></font> temp <font color='#5555FF'>=</font> <font color='#BB00BB'>anet</font><font face='Lucida Console'>(</font>input_image<font face='Lucida Console'>)</font>; <font color='#009900'>// Convert the RGB values to indexes. </font> <font color='#BB00BB'>rgb_label_image_to_index_label_image</font><font face='Lucida Console'>(</font>rgb_label_image, index_label_image<font face='Lucida Console'>)</font>; <font color='#009900'>// Crop the net output to be exactly the same size as the input. </font> <font color='#0000FF'>const</font> chip_details <font color='#BB00BB'>chip_details</font><font face='Lucida Console'>(</font> <font color='#BB00BB'>centered_rect</font><font face='Lucida Console'>(</font>temp.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>/</font> <font color='#979000'>2</font>, temp.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>/</font> <font color='#979000'>2</font>, input_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, input_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font>input_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, input_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>)</font>; <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>temp, chip_details, net_output, <font color='#BB00BB'>interpolate_nearest_neighbor</font><font face='Lucida 
Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> nr <font color='#5555FF'>=</font> index_label_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> nc <font color='#5555FF'>=</font> index_label_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#009900'>// Compare the predicted values to the ground-truth values. </font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> r <font color='#5555FF'>=</font> <font color='#979000'>0</font>; r <font color='#5555FF'><</font> nr; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>r<font face='Lucida Console'>)</font> <b>{</b> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> c <font color='#5555FF'>=</font> <font color='#979000'>0</font>; c <font color='#5555FF'><</font> nc; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>c<font face='Lucida Console'>)</font> <b>{</b> <font color='#0000FF'>const</font> uint16_t truth <font color='#5555FF'>=</font> <font color='#BB00BB'>index_label_image</font><font face='Lucida Console'>(</font>r, c<font face='Lucida Console'>)</font>; <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>truth <font color='#5555FF'>!</font><font color='#5555FF'>=</font> dlib::loss_multiclass_log_per_pixel_::label_to_ignore<font face='Lucida Console'>)</font> <b>{</b> <font color='#0000FF'>const</font> uint16_t prediction <font color='#5555FF'>=</font> <font color='#BB00BB'>net_output</font><font face='Lucida Console'>(</font>r, c<font face='Lucida Console'>)</font>; <font color='#0000FF'>if</font> <font face='Lucida 
Console'>(</font>prediction <font color='#5555FF'>=</font><font color='#5555FF'>=</font> truth<font face='Lucida Console'>)</font> <b>{</b> <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right; <b>}</b> <font color='#0000FF'>else</font> <b>{</b> <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong; <b>}</b> <b>}</b> <b>}</b> <b>}</b> <b>}</b> <font color='#009900'>// Return the accuracy estimate. </font> <font color='#0000FF'>return</font> num_right <font color='#5555FF'>/</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font>num_right <font color='#5555FF'>+</font> num_wrong<font face='Lucida Console'>)</font>; <b>}</b> <font color='#009900'>// ---------------------------------------------------------------------------------------- </font> <font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font><font color='#5555FF'>*</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font> <b>{</b> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'><</font> <font color='#979000'>2</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> argc <font color='#5555FF'>></font> <font color='#979000'>3</font><font face='Lucida Console'>)</font> <b>{</b> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>To run this program you need a copy of the PASCAL VOC2012 dataset.</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; cout <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>You call this program like this: </font>" <font 
color='#5555FF'><</font><font color='#5555FF'><</font> endl; cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>./dnn_semantic_segmentation_train_ex /path/to/VOC2012 [minibatch-size]</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; <font color='#0000FF'>return</font> <font color='#979000'>1</font>; <b>}</b> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\nSCANNING PASCAL VOC2012 DATASET\n</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> listing <font color='#5555FF'>=</font> <font color='#BB00BB'>get_pascal_voc2012_train_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font>; cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>images in dataset: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> <b>{</b> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Didn't find the VOC2012 dataset. 
</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; <font color='#0000FF'>return</font> <font color='#979000'>1</font>; <b>}</b> <font color='#009900'>// a mini-batch smaller than the default can be used with GPUs having less memory </font> <font color='#0000FF'>const</font> <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font> minibatch_size <font color='#5555FF'>=</font> argc <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>3</font> ? std::<font color='#BB00BB'>stoi</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>2</font>]<font face='Lucida Console'>)</font> : <font color='#979000'>23</font>; cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>mini-batch size: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> minibatch_size <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> initial_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>; <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> weight_decay <font color='#5555FF'>=</font> <font color='#979000'>0.0001</font>; <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>; bnet_type bnet; dnn_trainer<font color='#5555FF'><</font>bnet_type<font color='#5555FF'>></font> <font color='#BB00BB'>trainer</font><font face='Lucida Console'>(</font>bnet,<font color='#BB00BB'>sgd</font><font face='Lucida Console'>(</font>weight_decay, momentum<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida 
Console'>(</font>initial_learning_rate<font face='Lucida Console'>)</font>; trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>pascal_voc2012_trainer_state_file.dat</font>", std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <font color='#009900'>// This threshold is probably excessively large. </font> trainer.<font color='#BB00BB'>set_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font color='#979000'>5000</font><font face='Lucida Console'>)</font>; <font color='#009900'>// Since the progress threshold is so large might as well set the batch normalization </font> <font color='#009900'>// stats window to something big too. </font> <font color='#BB00BB'>set_all_bn_running_stats_window_sizes</font><font face='Lucida Console'>(</font>bnet, <font color='#979000'>1000</font><font face='Lucida Console'>)</font>; <font color='#009900'>// Output training parameters. </font> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> endl <font color='#5555FF'><</font><font color='#5555FF'><</font> trainer <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> samples; std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font>uint16_t<font color='#5555FF'>></font><font color='#5555FF'>></font> labels; <font color='#009900'>// Start a bunch of threads that read images from disk and pull out random crops. It's </font> <font color='#009900'>// important to be sure to feed the GPU fast enough to keep it busy. Using multiple </font> <font color='#009900'>// threads for this kind of data preparation helps us do that. 
Each thread puts the </font> <font color='#009900'>// crops into the data queue. </font> dlib::pipe<font color='#5555FF'><</font>training_sample<font color='#5555FF'>></font> <font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>; <font color='#0000FF'>auto</font> f <font color='#5555FF'>=</font> [<font color='#5555FF'>&</font>data, <font color='#5555FF'>&</font>listing]<font face='Lucida Console'>(</font>time_t seed<font face='Lucida Console'>)</font> <b>{</b> dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font color='#5555FF'>+</font>seed<font face='Lucida Console'>)</font>; matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> input_image; matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> rgb_label_image; matrix<font color='#5555FF'><</font>uint16_t<font color='#5555FF'>></font> index_label_image; training_sample temp; <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>data.<font color='#BB00BB'>is_enabled</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#009900'>// Pick a random input image. </font> <font color='#0000FF'>const</font> image_info<font color='#5555FF'>&</font> image_info <font color='#5555FF'>=</font> listing[rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>]; <font color='#009900'>// Load the input image. 
</font> <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>input_image, image_info.image_filename<font face='Lucida Console'>)</font>; <font color='#009900'>// Load the ground-truth (RGB) labels. </font> <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>rgb_label_image, image_info.class_label_filename<font face='Lucida Console'>)</font>; <font color='#009900'>// Convert the RGB values to indexes. </font> <font color='#BB00BB'>rgb_label_image_to_index_label_image</font><font face='Lucida Console'>(</font>rgb_label_image, index_label_image<font face='Lucida Console'>)</font>; <font color='#009900'>// Randomly pick a part of the image. </font> <font color='#BB00BB'>randomly_crop_image</font><font face='Lucida Console'>(</font>input_image, index_label_image, temp, rnd<font face='Lucida Console'>)</font>; <font color='#009900'>// Push the result to be used by the trainer. </font> data.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>; <b>}</b> <b>}</b>; std::thread <font color='#BB00BB'>data_loader1</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; std::thread <font color='#BB00BB'>data_loader2</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; std::thread <font color='#BB00BB'>data_loader3</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font 
face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; std::thread <font color='#BB00BB'>data_loader4</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; <font color='#009900'>// The main training loop. Keep making mini-batches and giving them to the trainer. </font> <font color='#009900'>// We will run until the learning rate has dropped below 1e-4. </font> <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>4</font><font face='Lucida Console'>)</font> <b>{</b> samples.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; labels.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#009900'>// make a mini-batch </font> training_sample temp; <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> minibatch_size<font face='Lucida Console'>)</font> <b>{</b> data.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>; samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.input_image<font face='Lucida 
Console'>)</font><font face='Lucida Console'>)</font>; labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.label_image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; <b>}</b> trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>; <b>}</b> <font color='#009900'>// Training done, tell threads to stop and make sure to wait for them to finish before </font> <font color='#009900'>// moving on. </font> data.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; data_loader1.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; data_loader2.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; data_loader3.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; data_loader4.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#009900'>// also wait for threaded processing to stop in the trainer. 
</font> trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; bnet.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>saving network</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; <font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>semantic_segmentation_net_filename<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> bnet; <font color='#009900'>// Make a copy of the network to use it for inference. </font> anet_type anet <font color='#5555FF'>=</font> bnet; cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Testing the network...</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; <font color='#009900'>// Find the accuracy of the newly trained network on both the training and the validation sets. 
</font> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>train accuracy : </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>calculate_accuracy</font><font face='Lucida Console'>(</font>anet, <font color='#BB00BB'>get_pascal_voc2012_train_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>val accuracy : </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>calculate_accuracy</font><font face='Lucida Console'>(</font>anet, <font color='#BB00BB'>get_pascal_voc2012_val_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; <b>}</b> <font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&</font> e<font face='Lucida Console'>)</font> <b>{</b> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; <b>}</b> </pre></body></html>