|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "dnn_semantic_segmentation_ex.h"

#include <ctime>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include <dlib/data_io.h>
#include <dlib/image_transforms.h>
#include <dlib/dir_nav.h>

using namespace std;
using namespace dlib;
|
|
|
|
|
struct training_sample |
|
{ |
|
matrix<rgb_pixel> input_image; |
|
matrix<uint16_t> label_image; |
|
}; |
|
|
|
|
|
|
|
// Picks a random square sub-window of img for data augmentation.
//
// The side length is a random fraction in [0.466666666, 0.875) of the image's
// shorter dimension (truncated to whole pixels). The window is then placed at a
// random position such that it lies entirely inside the image, and every valid
// position (including ones touching the right/bottom edge) can be chosen.
//
// img: only its dimensions are read.
// rnd: advanced by one get_random_double() draw followed by two
//      get_random_32bit_number() draws (x offset, then y offset).
// returns: a square rectangle contained in the image. For an empty image the
//          result is an empty rectangle at the origin.
rectangle make_random_cropping_rect(
    const matrix<rgb_pixel>& img,
    dlib::rand& rnd
)
{
    const double mins = 0.466666666, maxs = 0.875;
    const double scale = mins + rnd.get_random_double()*(maxs-mins);
    // Truncate to whole pixels explicitly rather than via an implicit
    // double -> unsigned long conversion in the rectangle constructor.
    const auto side = static_cast<unsigned long>(scale*std::min(img.nr(), img.nc()));
    const rectangle rect(side, side);

    // A window of width w fits at nc - w + 1 horizontal offsets (0 .. nc - w);
    // likewise vertically. Since side <= 0.875*min(nr, nc), both counts are >= 1,
    // so the modulus below can never be zero (even for an empty image).
    const unsigned long x_positions = img.nc() - rect.width() + 1;
    const unsigned long y_positions = img.nr() - rect.height() + 1;

    // Draw x then y in a fixed order (function-argument evaluation order is unspecified).
    const long x = static_cast<long>(rnd.get_random_32bit_number() % x_positions);
    const long y = static_cast<long>(rnd.get_random_32bit_number() % y_positions);
    return move_rect(rect, point(x, y));
}
|
|
|
|
|
|
|
void randomly_crop_image ( |
|
const matrix<rgb_pixel>& input_image, |
|
const matrix<uint16_t>& label_image, |
|
training_sample& crop, |
|
dlib::rand& rnd |
|
) |
|
{ |
|
const auto rect = make_random_cropping_rect(input_image, rnd); |
|
|
|
const chip_details chip_details(rect, chip_dims(227, 227)); |
|
|
|
|
|
extract_image_chip(input_image, chip_details, crop.input_image, interpolate_bilinear()); |
|
|
|
|
|
|
|
|
|
extract_image_chip(label_image, chip_details, crop.label_image, interpolate_nearest_neighbor()); |
|
|
|
|
|
if (rnd.get_random_double() > 0.5) |
|
{ |
|
crop.input_image = fliplr(crop.input_image); |
|
crop.label_image = fliplr(crop.label_image); |
|
} |
|
|
|
|
|
apply_random_color_offset(crop.input_image, rnd); |
|
} |
|
|
|
|
|
|
|
|
|
// Computes the pixel-wise classification accuracy of anet over every entry in dataset.
//
// For each entry the input image and its RGB-encoded label image are loaded from
// disk, and the label image is converted to class indices. The network output is
// resampled (nearest neighbour) onto a grid of the input image's size, centred on
// the output, so that it can be compared pixel by pixel. Pixels whose ground truth
// equals loss_multiclass_log_per_pixel_::label_to_ignore are not counted.
//
// Returns num_correct / num_counted. If no pixel is counted (e.g. an empty dataset)
// the result is NaN (0.0 / 0.0).
// Throws std::runtime_error if a label image's size differs from its input image's.
double calculate_accuracy(anet_type& anet, const std::vector<image_info>& dataset)
{
    // 64-bit unsigned counters: summed over a whole dataset the number of pixels
    // gets large, and overflowing a signed int would be undefined behaviour.
    unsigned long long num_right = 0;
    unsigned long long num_wrong = 0;

    // Buffers are reused across iterations to avoid reallocating per image.
    matrix<rgb_pixel> input_image;
    matrix<rgb_pixel> rgb_label_image;
    matrix<uint16_t> index_label_image;
    matrix<uint16_t> net_output;

    for (const auto& entry : dataset)
    {
        load_image(input_image, entry.image_filename);
        load_image(rgb_label_image, entry.class_label_filename);

        const matrix<uint16_t> temp = anet(input_image);

        rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);

        // The code allows for the network output differing in size from the input:
        // take an input-sized window centred on the output so the grids line up.
        const chip_details chip(
            centered_rect(temp.nc() / 2, temp.nr() / 2, input_image.nc(), input_image.nr()),
            chip_dims(input_image.nr(), input_image.nc())
        );
        extract_image_chip(temp, chip, net_output, interpolate_nearest_neighbor());

        const long nr = index_label_image.nr();
        const long nc = index_label_image.nc();

        // net_output has the input image's size; indexing it with the label image's
        // dimensions would read out of bounds if the two images disagree.
        if (nr != net_output.nr() || nc != net_output.nc())
            throw std::runtime_error("label image size does not match input image size");

        for (long r = 0; r < nr; ++r)
        {
            for (long c = 0; c < nc; ++c)
            {
                const uint16_t truth = index_label_image(r, c);
                if (truth == dlib::loss_multiclass_log_per_pixel_::label_to_ignore)
                    continue;

                if (net_output(r, c) == truth)
                    ++num_right;
                else
                    ++num_wrong;
            }
        }
    }

    return num_right / static_cast<double>(num_right + num_wrong);
}
|
|
|
|
|
|
|
int main(int argc, char** argv) try |
|
{ |
|
if (argc < 2 || argc > 3) |
|
{ |
|
cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl; |
|
cout << endl; |
|
cout << "You call this program like this: " << endl; |
|
cout << "./dnn_semantic_segmentation_train_ex /path/to/VOC2012 [minibatch-size]" << endl; |
|
return 1; |
|
} |
|
|
|
cout << "\nSCANNING PASCAL VOC2012 DATASET\n" << endl; |
|
|
|
const auto listing = get_pascal_voc2012_train_listing(argv[1]); |
|
cout << "images in dataset: " << listing.size() << endl; |
|
if (listing.size() == 0) |
|
{ |
|
cout << "Didn't find the VOC2012 dataset. " << endl; |
|
return 1; |
|
} |
|
|
|
|
|
const unsigned int minibatch_size = argc == 3 ? std::stoi(argv[2]) : 23; |
|
cout << "mini-batch size: " << minibatch_size << endl; |
|
|
|
const double initial_learning_rate = 0.1; |
|
const double weight_decay = 0.0001; |
|
const double momentum = 0.9; |
|
|
|
bnet_type bnet; |
|
dnn_trainer<bnet_type> trainer(bnet,sgd(weight_decay, momentum)); |
|
trainer.be_verbose(); |
|
trainer.set_learning_rate(initial_learning_rate); |
|
trainer.set_synchronization_file("pascal_voc2012_trainer_state_file.dat", std::chrono::minutes(10)); |
|
|
|
trainer.set_iterations_without_progress_threshold(5000); |
|
|
|
|
|
set_all_bn_running_stats_window_sizes(bnet, 1000); |
|
|
|
|
|
cout << endl << trainer << endl; |
|
|
|
std::vector<matrix<rgb_pixel>> samples; |
|
std::vector<matrix<uint16_t>> labels; |
|
|
|
|
|
|
|
|
|
|
|
dlib::pipe<training_sample> data(200); |
|
auto f = [&data, &listing](time_t seed) |
|
{ |
|
dlib::rand rnd(time(0)+seed); |
|
matrix<rgb_pixel> input_image; |
|
matrix<rgb_pixel> rgb_label_image; |
|
matrix<uint16_t> index_label_image; |
|
training_sample temp; |
|
while(data.is_enabled()) |
|
{ |
|
|
|
const image_info& image_info = listing[rnd.get_random_32bit_number()%listing.size()]; |
|
|
|
|
|
load_image(input_image, image_info.image_filename); |
|
|
|
|
|
load_image(rgb_label_image, image_info.class_label_filename); |
|
|
|
|
|
rgb_label_image_to_index_label_image(rgb_label_image, index_label_image); |
|
|
|
|
|
randomly_crop_image(input_image, index_label_image, temp, rnd); |
|
|
|
|
|
data.enqueue(temp); |
|
} |
|
}; |
|
std::thread data_loader1([f](){ f(1); }); |
|
std::thread data_loader2([f](){ f(2); }); |
|
std::thread data_loader3([f](){ f(3); }); |
|
std::thread data_loader4([f](){ f(4); }); |
|
|
|
|
|
|
|
while(trainer.get_learning_rate() >= 1e-4) |
|
{ |
|
samples.clear(); |
|
labels.clear(); |
|
|
|
|
|
training_sample temp; |
|
while(samples.size() < minibatch_size) |
|
{ |
|
data.dequeue(temp); |
|
|
|
samples.push_back(std::move(temp.input_image)); |
|
labels.push_back(std::move(temp.label_image)); |
|
} |
|
|
|
trainer.train_one_step(samples, labels); |
|
} |
|
|
|
|
|
|
|
data.disable(); |
|
data_loader1.join(); |
|
data_loader2.join(); |
|
data_loader3.join(); |
|
data_loader4.join(); |
|
|
|
|
|
trainer.get_net(); |
|
|
|
bnet.clean(); |
|
cout << "saving network" << endl; |
|
serialize(semantic_segmentation_net_filename) << bnet; |
|
|
|
|
|
|
|
anet_type anet = bnet; |
|
|
|
cout << "Testing the network..." << endl; |
|
|
|
|
|
cout << "train accuracy : " << calculate_accuracy(anet, get_pascal_voc2012_train_listing(argv[1])) << endl; |
|
cout << "val accuracy : " << calculate_accuracy(anet, get_pascal_voc2012_val_listing(argv[1])) << endl; |
|
} |
|
catch(std::exception& e) |
|
{ |
|
cout << e.what() << endl; |
|
} |
|
|
|
|