File size: 10,701 Bytes
9375c9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 |
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This example shows how to train a semantic segmentation net using the PASCAL VOC2012
dataset. For an introduction to what segmentation is, see the accompanying header file
dnn_semantic_segmentation_ex.h.
Instructions on how to run the example:
1. Download the PASCAL VOC2012 data, and untar it somewhere.
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
2. Build the dnn_semantic_segmentation_train_ex example program.
3. Run:
./dnn_semantic_segmentation_train_ex /path/to/VOC2012
4. Wait while the network is being trained.
5. Build the dnn_semantic_segmentation_ex example program.
6. Run:
./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images
It would be a good idea to become familiar with dlib's DNN tooling before reading this
example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
before reading this example program.
*/
#include "dnn_semantic_segmentation_ex.h"
#include <iostream>
#include <dlib/data_io.h>
#include <dlib/image_transforms.h>
#include <dlib/dir_nav.h>
#include <iterator>
#include <thread>
using namespace std;
using namespace dlib;
// A single training sample: an input image paired with the label of each of its
// pixels. A mini-batch comprises many of these.
struct training_sample
{
// The RGB image that is fed to the network.
matrix<rgb_pixel> input_image;
matrix<uint16_t> label_image; // The ground-truth label of each pixel.
};
// ----------------------------------------------------------------------------------------
// Chooses a random square region of img to be used as a training crop.
//
// The side of the square is a random fraction, between 0.466666666 and 0.875, of
// the shorter image dimension. The square is then shifted by a random offset that
// is smaller than the space left over in each direction.
//
// img - the image the crop is to be taken from; only its dimensions are used.
// rnd - the random number generator that supplies the scale and the offset.
rectangle make_random_cropping_rect(
    const matrix<rgb_pixel>& img,
    dlib::rand& rnd
)
{
    // Limits for the crop side, relative to the shorter image dimension.
    const double min_scale = 0.466666666;
    const double max_scale = 0.875;

    // Draw the relative scale first and turn it into a side length in pixels.
    const auto relative_side = min_scale + rnd.get_random_double()*(max_scale - min_scale);
    const auto side = relative_side*std::min(img.nr(), img.nc());
    rectangle crop_rect(side, side);

    // The amount of room that is left for moving the square around.
    const auto horizontal_room = img.nc() - crop_rect.width();
    const auto vertical_room   = img.nr() - crop_rect.height();

    // Pick a random shift within that room and apply it.
    const point shift(rnd.get_random_32bit_number() % horizontal_room,
                      rnd.get_random_32bit_number() % vertical_room);
    return move_rect(crop_rect, shift);
}
// ----------------------------------------------------------------------------------------
void randomly_crop_image (
const matrix<rgb_pixel>& input_image,
const matrix<uint16_t>& label_image,
training_sample& crop,
dlib::rand& rnd
)
{
const auto rect = make_random_cropping_rect(input_image, rnd);
const chip_details chip_details(rect, chip_dims(227, 227));
// Crop the input image.
extract_image_chip(input_image, chip_details, crop.input_image, interpolate_bilinear());
// Crop the labels correspondingly. However, note that here bilinear
// interpolation would make absolutely no sense - you wouldn't say that
// a bicycle is half-way between an aeroplane and a bird, would you?
extract_image_chip(label_image, chip_details, crop.label_image, interpolate_nearest_neighbor());
// Also randomly flip the input image and the labels.
if (rnd.get_random_double() > 0.5)
{
crop.input_image = fliplr(crop.input_image);
crop.label_image = fliplr(crop.label_image);
}
// And then randomly adjust the colors.
apply_random_color_offset(crop.input_image, rnd);
}
// ----------------------------------------------------------------------------------------
// Calculate the per-pixel accuracy on a dataset whose file names are supplied as a parameter.
//
// anet    - the network used to predict a class index for every pixel.
// dataset - the images to evaluate; for each entry the input image and the RGB
//           class label image are loaded from the file names it holds.
//
// Returns right/(right + wrong), counted over all pixels whose ground-truth label
// is not label_to_ignore. If not a single pixel was counted (e.g. the dataset is
// empty) the result is 0/0, i.e. NaN on IEEE-754 platforms.
double calculate_accuracy(anet_type& anet, const std::vector<image_info>& dataset)
{
    // The counters accumulate over every pixel of every image, so they are 64 bits
    // wide: a plain int would overflow (undefined behavior) on large datasets.
    unsigned long long num_right = 0;
    unsigned long long num_wrong = 0;

    // These buffers are reused for every image of the dataset.
    matrix<rgb_pixel> input_image;
    matrix<rgb_pixel> rgb_label_image;
    matrix<uint16_t> index_label_image;
    matrix<uint16_t> net_output;

    for (const auto& info : dataset)
    {
        // Load the input image.
        load_image(input_image, info.image_filename);

        // Load the ground-truth (RGB) labels.
        load_image(rgb_label_image, info.class_label_filename);

        // Create predictions for each pixel. At this point, the type of each prediction
        // is an index (a value between 0 and 20). Note that the net may return an image
        // that is not exactly the same size as the input.
        const matrix<uint16_t> temp = anet(input_image);

        // Convert the RGB values to indexes.
        rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);

        // Crop the net output to be exactly the same size as the input. The crop is
        // centered on the middle of the net output.
        const chip_details output_crop(
            centered_rect(temp.nc() / 2, temp.nr() / 2, input_image.nc(), input_image.nr()),
            chip_dims(input_image.nr(), input_image.nc())
        );
        extract_image_chip(temp, output_crop, net_output, interpolate_nearest_neighbor());

        const long nr = index_label_image.nr();
        const long nc = index_label_image.nc();

        // Compare the predicted values to the ground-truth values.
        for (long r = 0; r < nr; ++r)
        {
            for (long c = 0; c < nc; ++c)
            {
                const uint16_t truth = index_label_image(r, c);

                // Pixels marked to be ignored count neither as right nor as wrong.
                if (truth == dlib::loss_multiclass_log_per_pixel_::label_to_ignore)
                    continue;

                const uint16_t prediction = net_output(r, c);
                if (prediction == truth)
                    ++num_right;
                else
                    ++num_wrong;
            }
        }
    }

    // Return the accuracy estimate.
    return num_right / static_cast<double>(num_right + num_wrong);
}
// ----------------------------------------------------------------------------------------
// Trains the semantic segmentation network on the PASCAL VOC2012 dataset, saves it,
// and prints its per-pixel accuracy on the training and the validation sets.
//
// Usage: ./dnn_semantic_segmentation_train_ex /path/to/VOC2012 [minibatch-size]
//
// Returns 0 on success. Returns 1 if the arguments are unusable, if the dataset
// is not found, or if an exception derived from std::exception is thrown.
int main(int argc, char** argv) try
{
    if (argc < 2 || argc > 3)
    {
        cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl;
        cout << endl;
        cout << "You call this program like this: " << endl;
        cout << "./dnn_semantic_segmentation_train_ex /path/to/VOC2012 [minibatch-size]" << endl;
        return 1;
    }

    cout << "\nSCANNING PASCAL VOC2012 DATASET\n" << endl;

    const auto listing = get_pascal_voc2012_train_listing(argv[1]);
    cout << "images in dataset: " << listing.size() << endl;
    if (listing.size() == 0)
    {
        cout << "Didn't find the VOC2012 dataset. " << endl;
        return 1;
    }

    // a mini-batch smaller than the default can be used with GPUs having less memory
    unsigned int minibatch_size = 23;
    if (argc == 3)
    {
        // std::stoi throws if the argument is not a number; that is reported by the
        // handler at the end of main. Zero and negative values parse fine, though,
        // and must be rejected here: stored in an unsigned variable a negative value
        // would wrap around to a huge size, and a size of zero would yield empty
        // mini-batches.
        const int requested_size = std::stoi(argv[2]);
        if (requested_size < 1)
        {
            cout << "The mini-batch size must be a positive integer, but got: " << argv[2] << endl;
            return 1;
        }
        minibatch_size = static_cast<unsigned int>(requested_size);
    }
    cout << "mini-batch size: " << minibatch_size << endl;

    const double initial_learning_rate = 0.1;
    const double weight_decay = 0.0001;
    const double momentum = 0.9;

    bnet_type bnet;
    dnn_trainer<bnet_type> trainer(bnet,sgd(weight_decay, momentum));
    trainer.be_verbose();
    trainer.set_learning_rate(initial_learning_rate);
    trainer.set_synchronization_file("pascal_voc2012_trainer_state_file.dat", std::chrono::minutes(10));
    // This threshold is probably excessively large.
    trainer.set_iterations_without_progress_threshold(5000);
    // Since the progress threshold is so large might as well set the batch normalization
    // stats window to something big too.
    set_all_bn_running_stats_window_sizes(bnet, 1000);

    // Output training parameters.
    cout << endl << trainer << endl;

    std::vector<matrix<rgb_pixel>> samples;
    std::vector<matrix<uint16_t>> labels;

    // Start a bunch of threads that read images from disk and pull out random crops.  It's
    // important to be sure to feed the GPU fast enough to keep it busy.  Using multiple
    // threads for this kind of data preparation helps us do that.  Each thread puts the
    // crops into the data queue.
    dlib::pipe<training_sample> data(200);
    auto f = [&data, &listing](time_t seed)
    {
        dlib::rand rnd(time(0)+seed);
        matrix<rgb_pixel> input_image;
        matrix<rgb_pixel> rgb_label_image;
        matrix<uint16_t> index_label_image;
        training_sample temp;
        while(data.is_enabled())
        {
            // Pick a random input image.
            const image_info& image_info = listing[rnd.get_random_32bit_number()%listing.size()];

            // Load the input image.
            load_image(input_image, image_info.image_filename);

            // Load the ground-truth (RGB) labels.
            load_image(rgb_label_image, image_info.class_label_filename);

            // Convert the RGB values to indexes.
            rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);

            // Randomly pick a part of the image.
            randomly_crop_image(input_image, index_label_image, temp, rnd);

            // Push the result to be used by the trainer.
            data.enqueue(temp);
        }
    };

    // Stops and joins the loader threads, both when asked to and when the scope is
    // left.  Destroying a std::thread that is still joinable calls std::terminate(),
    // so without this an exception thrown below (for example during training) would
    // kill the program before the handler at the end of main could report it.
    struct loader_stopper
    {
        dlib::pipe<training_sample>& queue;
        std::vector<std::thread>& threads;

        loader_stopper(
            dlib::pipe<training_sample>& queue_,
            std::vector<std::thread>& threads_
        ) : queue(queue_), threads(threads_) {}

        loader_stopper(const loader_stopper&) = delete;
        loader_stopper& operator=(const loader_stopper&) = delete;

        ~loader_stopper() { stop(); }

        void stop()
        {
            // Once the pipe is disabled the loader loops see is_enabled() == false
            // and return.  Calling stop() more than once is harmless because
            // threads that were already joined are skipped.
            queue.disable();
            for (auto& t : threads)
            {
                if (t.joinable())
                    t.join();
            }
        }
    };

    std::vector<std::thread> data_loaders;
    data_loaders.reserve(4);
    // The stopper is declared after the vector, hence destroyed before it, and it is
    // created before the first thread so that a failure to start a later thread
    // still joins the ones that are already running.
    loader_stopper stopper(data, data_loaders);
    for (time_t seed = 1; seed <= 4; ++seed)
        data_loaders.emplace_back([f, seed](){ f(seed); });

    // The main training loop.  Keep making mini-batches and giving them to the trainer.
    // We will run until the learning rate has dropped below 1e-4.
    while(trainer.get_learning_rate() >= 1e-4)
    {
        samples.clear();
        labels.clear();

        // make a mini-batch
        training_sample temp;
        while(samples.size() < minibatch_size)
        {
            data.dequeue(temp);

            samples.push_back(std::move(temp.input_image));
            labels.push_back(std::move(temp.label_image));
        }

        trainer.train_one_step(samples, labels);
    }

    // Training done, tell threads to stop and make sure to wait for them to finish before
    // moving on.
    stopper.stop();

    // also wait for threaded processing to stop in the trainer.
    trainer.get_net();

    bnet.clean();
    cout << "saving network" << endl;
    serialize(semantic_segmentation_net_filename) << bnet;

    // Make a copy of the network to use it for inference.
    anet_type anet = bnet;

    cout << "Testing the network..." << endl;

    // Find the accuracy of the newly trained network on both the training and the validation sets.
    // The training listing was already scanned above, so it is reused here.
    cout << "train accuracy  :  " << calculate_accuracy(anet, listing) << endl;
    cout << "val accuracy    :  " << calculate_accuracy(anet, get_pascal_voc2012_val_listing(argv[1])) << endl;
}
catch(const std::exception& e)
{
    cout << e.what() << endl;
    // Report the failure to the caller instead of exiting with status 0.
    return 1;
}
|