#include "dnn_instance_segmentation_ex.h" |
#include "pascal_voc_2012.h" |
#include <iostream> |
#include <dlib/data_io.h> |
#include <dlib/image_transforms.h> |
#include <dlib/dir_nav.h> |
#include <iterator> |
#include <thread> |
#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) |
#include <execution> |
#endif |
using namespace std; |
using namespace dlib; |
struct det_training_sample |
{ |
matrix<rgb_pixel> input_image; |
std::vector<dlib::mmod_rect> mmod_rects; |
}; |
struct seg_training_sample |
{ |
matrix<rgb_pixel> input_image; |
matrix<float> label_image; |
}; |
// Returns true when rgb_label can belong to an object instance, i.e. when it
// is neither pure black (0,0,0) nor the colour (224,224,192).
// NOTE(review): these two colours are presumably the VOC2012 background and
// border/"don't care" labels - confirm against the dataset documentation.
bool is_instance_pixel(const dlib::rgb_pixel& rgb_label)
{
    const bool is_black = (rgb_label == dlib::rgb_pixel(0, 0, 0));
    const bool is_special = (rgb_label == dlib::rgb_pixel(224, 224, 192));
    return !(is_black || is_special);
}
namespace std { |
template <> |
struct hash<dlib::rgb_pixel> |
{ |
std::size_t operator()(const dlib::rgb_pixel& p) const |
{ |
return (static_cast<uint32_t>(p.red) << 16) |
| (static_cast<uint32_t>(p.green) << 8) |
| (static_cast<uint32_t>(p.blue)); |
} |
}; |
} |
struct truth_instance |
{ |
dlib::rgb_pixel rgb_label; |
dlib::mmod_rect mmod_rect; |
}; |
std::vector<truth_instance> rgb_label_images_to_truth_instances( |
const dlib::matrix<dlib::rgb_pixel>& instance_label_image, |
const dlib::matrix<dlib::rgb_pixel>& class_label_image |
) |
{ |
std::unordered_map<dlib::rgb_pixel, mmod_rect> result_map; |
DLIB_CASSERT(instance_label_image.nr() == class_label_image.nr()); |
DLIB_CASSERT(instance_label_image.nc() == class_label_image.nc()); |
const auto nr = instance_label_image.nr(); |
const auto nc = instance_label_image.nc(); |
for (int r = 0; r < nr; ++r) |
{ |
for (int c = 0; c < nc; ++c) |
{ |
const auto rgb_instance_label = instance_label_image(r, c); |
if (!is_instance_pixel(rgb_instance_label)) |
continue; |
const auto rgb_class_label = class_label_image(r, c); |
const Voc2012class& voc2012_class = find_voc2012_class(rgb_class_label); |
const auto i = result_map.find(rgb_instance_label); |
if (i == result_map.end()) |
{ |
result_map[rgb_instance_label] = rectangle(c, r, c, r); |
result_map[rgb_instance_label].label = voc2012_class.classlabel; |
} |
else |
{ |
auto& rect = i->second.rect; |
if (c < rect.left()) |
rect.set_left(c); |
else if (c > rect.right()) |
rect.set_right(c); |
if (r > rect.bottom()) |
rect.set_bottom(r); |
DLIB_CASSERT(i->second.label == voc2012_class.classlabel); |
} |
} |
} |
std::vector<truth_instance> flat_result; |
flat_result.reserve(result_map.size()); |
for (const auto& i : result_map) { |
flat_result.push_back(truth_instance{ |
i.first, i.second |
}); |
} |
return flat_result; |
} |
struct truth_image |
{ |
image_info info; |
std::vector<truth_instance> truth_instances; |
}; |
std::vector<mmod_rect> extract_mmod_rects( |
const std::vector<truth_instance>& truth_instances |
) |
{ |
std::vector<mmod_rect> mmod_rects(truth_instances.size()); |
std::transform( |
truth_instances.begin(), |
truth_instances.end(), |
mmod_rects.begin(), |
[](const truth_instance& truth) { return truth.mmod_rect; } |
); |
return mmod_rects; |
} |
std::vector<std::vector<mmod_rect>> extract_mmod_rect_vectors( |
const std::vector<truth_image>& truth_images |
) |
{ |
std::vector<std::vector<mmod_rect>> mmod_rects(truth_images.size()); |
const auto extract_mmod_rects_from_truth_image = [](const truth_image& truth_image) |
{ |
return extract_mmod_rects(truth_image.truth_instances); |
}; |
std::transform( |
truth_images.begin(), |
truth_images.end(), |
mmod_rects.begin(), |
extract_mmod_rects_from_truth_image |
); |
return mmod_rects; |
} |
// Trains the object detection network on the given truth images.
//
// Four background threads continuously load random images, crop them with a
// random_cropper, disturb their colours and push the result into a pipe; the
// calling thread pulls mini-batches of det_minibatch_size samples from that
// pipe and feeds them to the trainer until the learning rate drops below
// 1e-4.  Trainer state is periodically synchronised to
// "pascal_voc2012_det_trainer_state_file.dat".
//
// Parameters:
//   truth_images        - training images with their truth boxes; must not
//                         be empty.
//   det_minibatch_size  - number of crops per training step; must be > 0.
//
// Returns the trained (and cleaned) network.
// Throws dlib::error when truth_images is empty or det_minibatch_size is 0;
// any exception raised during training is rethrown after the loader threads
// have been stopped and joined.
det_bnet_type train_detection_network(
    const std::vector<truth_image>& truth_images,
    unsigned int det_minibatch_size
)
{
    // The loaders index with `% truth_images.size()`, which would divide by
    // zero on an empty set, and a zero-sized mini-batch can never be trained.
    if (truth_images.empty())
        throw dlib::error("train_detection_network: no training images were given");
    if (det_minibatch_size == 0)
        throw dlib::error("train_detection_network: the mini-batch size must be positive");

    const double initial_learning_rate = 0.1;
    const double weight_decay = 0.0001;
    const double momentum = 0.9;
    const double min_detector_window_overlap_iou = 0.65;
    const int target_size = 70;
    const int min_target_size = 30;

    mmod_options options(
        extract_mmod_rect_vectors(truth_images),
        target_size, min_target_size,
        min_detector_window_overlap_iou
    );
    options.overlaps_ignore = test_box_overlap(0.5, 0.9);

    det_bnet_type det_net(options);
    // One output filter per detector window chosen by mmod_options.
    det_net.subnet().layer_details().set_num_filters(options.detector_windows.size());

    dnn_trainer<det_bnet_type> det_trainer(det_net, sgd(weight_decay, momentum));

    dlib::pipe<det_training_sample> data(200);

    // Body of a data-loader thread.  `seed` differs per thread so that the
    // threads do not all produce the same sequence of samples.
    auto f = [&data, &truth_images, target_size, min_target_size](time_t seed)
    {
        dlib::rand rnd(time(0) + seed);
        matrix<rgb_pixel> input_image;

        random_cropper cropper;
        // Seed the cropper from the per-thread generator.  Seeding it with
        // plain time(0) would give every thread started within the same
        // second an identical cropping sequence.
        cropper.set_seed(static_cast<time_t>(rnd.get_random_32bit_number()));
        cropper.set_chip_dims(350, 350);
        // Slightly smaller than the sizes given to mmod_options above.
        cropper.set_min_object_size(target_size - 2, min_target_size - 2);
        cropper.set_max_rotation_degrees(2);

        det_training_sample temp;

        while (data.is_enabled())
        {
            const auto random_index = rnd.get_random_32bit_number() % truth_images.size();
            const auto& truth_image = truth_images[random_index];

            load_image(input_image, truth_image.info.image_filename);
            const auto mmod_rects = extract_mmod_rects(truth_image.truth_instances);

            cropper(input_image, mmod_rects, temp.input_image, temp.mmod_rects);
            disturb_colors(temp.input_image, rnd);

            data.enqueue(temp);
        }
    };

    std::vector<std::thread> data_loaders;
    data_loaders.reserve(4);

    // Disabling the pipe makes the loaders leave their loop; every thread
    // that was actually started is then joined.
    const auto stop_data_loaders = [&data, &data_loaders]()
    {
        data.disable();
        for (auto& loader : data_loaders)
        {
            if (loader.joinable())
                loader.join();
        }
    };

    try
    {
        for (time_t seed = 1; seed <= 4; ++seed)
            data_loaders.emplace_back([f, seed]() { f(seed); });

        det_trainer.be_verbose();
        det_trainer.set_learning_rate(initial_learning_rate);
        det_trainer.set_synchronization_file("pascal_voc2012_det_trainer_state_file.dat", std::chrono::minutes(10));
        det_trainer.set_iterations_without_progress_threshold(5000);

        cout << det_trainer << endl;

        std::vector<matrix<rgb_pixel>> samples;
        std::vector<std::vector<mmod_rect>> labels;

        while (det_trainer.get_learning_rate() >= 1e-4)
        {
            samples.clear();
            labels.clear();

            det_training_sample temp;
            while (samples.size() < det_minibatch_size)
            {
                data.dequeue(temp);
                samples.push_back(std::move(temp.input_image));
                labels.push_back(std::move(temp.mmod_rects));
            }

            det_trainer.train_one_step(samples, labels);
        }
    }
    catch (...)
    {
        // Whatever went wrong (including non-std exceptions or a failure to
        // start a thread), never leave a joinable std::thread behind: its
        // destructor would call std::terminate.
        stop_data_loaders();
        throw;
    }

    stop_data_loaders();

    // Return value intentionally unused; the call is kept from the original
    // code. NOTE(review): presumably it waits for the trainer's background
    // work to finish before det_net is touched - confirm in the dlib docs.
    det_trainer.get_net();

    det_net.clean();

    return det_net;
}
// Converts an RGB instance label image into a float label image for one
// particular instance.  Each output pixel is
//   +1  where the input equals rgb_label,
//    0  where the input equals the colour (224,224,192),
//   -1  everywhere else.
// The output has the same dimensions as rgb_label_image.
matrix<float> keep_only_current_instance(const matrix<rgb_pixel>& rgb_label_image, const rgb_pixel rgb_label)
{
    const dlib::rgb_pixel neutral_color(224, 224, 192);

    const auto num_rows = rgb_label_image.nr();
    const auto num_cols = rgb_label_image.nc();

    matrix<float> labels(num_rows, num_cols);

    for (long row = 0; row < num_rows; ++row)
    {
        for (long col = 0; col < num_cols; ++col)
        {
            const auto& pixel = rgb_label_image(row, col);

            float value = -1;
            if (pixel == rgb_label)
                value = +1;
            else if (pixel == neutral_color)
                value = 0;

            labels(row, col) = value;
        }
    }

    return labels;
}
// Trains a segmentation network on chips cut out around truth instances.
//
// Four background threads repeatedly pick a random image and a random
// instance in it, cut a randomly translated chip of seg_dim x seg_dim pixels
// around the instance from both the input image and the instance label
// image, and push the (image, label) pair into a pipe.  The calling thread
// pulls mini-batches of seg_minibatch_size samples from that pipe and trains
// until the learning rate drops below 1e-4.
//
// Parameters:
//   truth_images        - training images; at least one of them must contain
//                         a truth instance.
//   seg_minibatch_size  - number of chips per training step; must be > 0.
//   classlabel          - only used to build the name of the trainer
//                         synchronisation file; may be empty.
//
// Returns the trained (and cleaned) network.
// Throws dlib::error when there is nothing to train on or when
// seg_minibatch_size is 0; any exception raised during training is rethrown
// after the loader threads have been stopped and joined.
seg_bnet_type train_segmentation_network(
    const std::vector<truth_image>& truth_images,
    unsigned int seg_minibatch_size,
    const std::string& classlabel
)
{
    // Without a single usable instance the loaders would either divide by
    // zero (`% truth_images.size()`) or spin forever without producing
    // samples, leaving the training loop blocked on the pipe.
    const bool has_any_instance = std::any_of(
        truth_images.begin(),
        truth_images.end(),
        [](const truth_image& image) { return !image.truth_instances.empty(); }
    );
    if (!has_any_instance)
        throw dlib::error("train_segmentation_network: no truth instances to train on");
    if (seg_minibatch_size == 0)
        throw dlib::error("train_segmentation_network: the mini-batch size must be positive");

    seg_bnet_type seg_net;

    const double initial_learning_rate = 0.1;
    const double weight_decay = 0.0001;
    const double momentum = 0.9;

    // One state file per class, so several trainings do not overwrite each
    // other's synchronisation data.
    const std::string synchronization_file_name
        = "pascal_voc2012_seg_trainer_state_file"
        + (classlabel.empty() ? "" : ("_" + classlabel))
        + ".dat";

    dnn_trainer<seg_bnet_type> seg_trainer(seg_net, sgd(weight_decay, momentum));
    seg_trainer.be_verbose();
    seg_trainer.set_learning_rate(initial_learning_rate);
    seg_trainer.set_synchronization_file(synchronization_file_name, std::chrono::minutes(10));
    seg_trainer.set_iterations_without_progress_threshold(2000);
    set_all_bn_running_stats_window_sizes(seg_net, 1000);

    cout << seg_trainer << endl;

    std::vector<matrix<rgb_pixel>> samples;
    std::vector<matrix<float>> labels;

    dlib::pipe<seg_training_sample> data(200);

    // Body of a data-loader thread.  `seed` differs per thread so that the
    // threads do not all produce the same sequence of samples.
    auto f = [&data, &truth_images](time_t seed)
    {
        dlib::rand rnd(time(0) + seed);
        matrix<rgb_pixel> input_image;
        matrix<rgb_pixel> rgb_label_image;
        matrix<rgb_pixel> rgb_label_chip;
        seg_training_sample temp;

        while (data.is_enabled())
        {
            const auto random_index = rnd.get_random_32bit_number() % truth_images.size();
            const auto& truth_image = truth_images[random_index];

            // Bind by reference: copying the whole instance vector on every
            // iteration is pure overhead.
            const auto& image_truths = truth_image.truth_instances;
            if (image_truths.empty())
                continue;   // nothing to cut out of this image; try another

            const image_info& info = truth_image.info;
            load_image(input_image, info.image_filename);
            load_image(rgb_label_image, info.instance_label_filename);

            const auto& truth_instance = image_truths[rnd.get_random_32bit_number() % image_truths.size()];
            const auto& truth_rect = truth_instance.mmod_rect.rect;
            const auto cropping_rect = get_cropping_rect(truth_rect);

            // Shift the crop by a random offset bounded by a tenth of the
            // box size.  The upper bounds below are written as max + 1.
            // NOTE(review): presumably because get_integer_in_range treats
            // its end argument as exclusive - confirm in the dlib docs.
            const auto max_x_translate_amount = static_cast<long>(truth_rect.width() / 10.0);
            const auto max_y_translate_amount = static_cast<long>(truth_rect.height() / 10.0);

            const auto random_translate = point(
                rnd.get_integer_in_range(-max_x_translate_amount, max_x_translate_amount + 1),
                rnd.get_integer_in_range(-max_y_translate_amount, max_y_translate_amount + 1)
            );

            const rectangle random_rect(
                cropping_rect.left()   + random_translate.x(),
                cropping_rect.top()    + random_translate.y(),
                cropping_rect.right()  + random_translate.x(),
                cropping_rect.bottom() + random_translate.y()
            );

            const chip_details chip(random_rect, chip_dims(seg_dim, seg_dim));

            // Input image: smooth interpolation plus colour augmentation.
            extract_image_chip(input_image, chip, temp.input_image, interpolate_bilinear());
            disturb_colors(temp.input_image, rnd);

            // Label image: nearest neighbour, so that no new (blended)
            // label colours are invented at instance borders.
            extract_image_chip(rgb_label_image, chip, rgb_label_chip, interpolate_nearest_neighbor());
            temp.label_image = keep_only_current_instance(rgb_label_chip, truth_instance.rgb_label);

            data.enqueue(temp);
        }
    };

    std::vector<std::thread> data_loaders;
    data_loaders.reserve(4);

    // Disabling the pipe makes the loaders leave their loop; every thread
    // that was actually started is then joined.
    const auto stop_data_loaders = [&data, &data_loaders]()
    {
        data.disable();
        for (auto& loader : data_loaders)
        {
            if (loader.joinable())
                loader.join();
        }
    };

    try
    {
        for (time_t seed = 1; seed <= 4; ++seed)
            data_loaders.emplace_back([f, seed]() { f(seed); });

        while (seg_trainer.get_learning_rate() >= 1e-4)
        {
            samples.clear();
            labels.clear();

            seg_training_sample temp;
            while (samples.size() < seg_minibatch_size)
            {
                data.dequeue(temp);
                samples.push_back(std::move(temp.input_image));
                labels.push_back(std::move(temp.label_image));
            }

            seg_trainer.train_one_step(samples, labels);
        }
    }
    catch (...)
    {
        // Whatever went wrong (including non-std exceptions or a failure to
        // start a thread), never leave a joinable std::thread behind: its
        // destructor would call std::terminate.
        stop_data_loaders();
        throw;
    }

    stop_data_loaders();

    // Return value intentionally unused; the call is kept from the original
    // code. NOTE(review): presumably it waits for the trainer's background
    // work to finish before seg_net is touched - confirm in the dlib docs.
    seg_trainer.get_net();

    seg_net.clean();

    return seg_net;
}
int ignore_overlapped_boxes( |
std::vector<truth_instance>& truth_instances, |
const test_box_overlap& overlaps |
) |
{ |
int num_ignored = 0; |
for (size_t i = 0, end = truth_instances.size(); i < end; ++i) |
{ |
auto& box_i = truth_instances[i].mmod_rect; |
if (box_i.ignore) |
continue; |
for (size_t j = i+1; j < end; ++j) |
{ |
auto& box_j = truth_instances[j].mmod_rect; |
if (box_j.ignore) |
continue; |
if (overlaps(box_i, box_j)) |
{ |
++num_ignored; |
if(box_i.rect.area() < box_j.rect.area()) |
box_i.ignore = true; |
else |
box_j.ignore = true; |
} |
} |
} |
return num_ignored; |
} |
std::vector<truth_instance> load_truth_instances(const image_info& info) |
{ |
matrix<rgb_pixel> instance_label_image; |
matrix<rgb_pixel> class_label_image; |
load_image(instance_label_image, info.instance_label_filename); |
load_image(class_label_image, info.class_label_filename); |
return rgb_label_images_to_truth_instances(instance_label_image, class_label_image); |
} |
std::vector<std::vector<truth_instance>> load_all_truth_instances(const std::vector<image_info>& listing) |
{ |
std::vector<std::vector<truth_instance>> truth_instances(listing.size()); |
std::transform( |
#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) |
std::execution::par, |
#endif |
listing.begin(), |
listing.end(), |
truth_instances.begin(), |
load_truth_instances |
); |
return truth_instances; |
} |
std::vector<truth_image> filter_based_on_classlabel( |
const std::vector<truth_image>& truth_images, |
const std::vector<std::string>& desired_classlabels |
) |
{ |
std::vector<truth_image> result; |
const auto represents_desired_class = [&desired_classlabels](const truth_instance& truth_instance) { |
return std::find( |
desired_classlabels.begin(), |
desired_classlabels.end(), |
truth_instance.mmod_rect.label |
) != desired_classlabels.end(); |
}; |
for (const auto& input : truth_images) |
{ |
const auto has_desired_class = std::any_of( |
input.truth_instances.begin(), |
input.truth_instances.end(), |
represents_desired_class |
); |
if (has_desired_class) { |
std::vector<truth_instance> temp; |
std::copy_if( |
input.truth_instances.begin(), |
input.truth_instances.end(), |
std::back_inserter(temp), |
represents_desired_class |
); |
result.push_back(truth_image{ input.info, temp }); |
} |
} |
return result; |
} |
void ignore_some_truth_boxes(std::vector<truth_image>& truth_images) |
{ |
for (auto& i : truth_images) |
{ |
auto& truth_instances = i.truth_instances; |
ignore_overlapped_boxes(truth_instances, test_box_overlap(0.90, 0.95)); |
for (auto& truth : truth_instances) |
{ |
if (truth.mmod_rect.ignore) |
continue; |
const auto& rect = truth.mmod_rect.rect; |
constexpr unsigned long min_width = 35; |
constexpr unsigned long min_height = 35; |
if (rect.width() < min_width && rect.height() < min_height) |
{ |
truth.mmod_rect.ignore = true; |
continue; |
} |
constexpr double max_aspect_ratio_width_to_height = 3.0; |
constexpr double max_aspect_ratio_height_to_width = 1.5; |
const double aspect_ratio_width_to_height = rect.width() / static_cast<double>(rect.height()); |
const double aspect_ratio_height_to_width = 1.0 / aspect_ratio_width_to_height; |
const bool is_aspect_ratio_too_large |
= aspect_ratio_width_to_height > max_aspect_ratio_width_to_height |
|| aspect_ratio_height_to_width > max_aspect_ratio_height_to_width; |
if (is_aspect_ratio_too_large) |
truth.mmod_rect.ignore = true; |
} |
} |
} |
std::vector<truth_image> filter_images_with_no_truth(const std::vector<truth_image>& truth_images) |
{ |
std::vector<truth_image> result; |
for (const auto& truth_image : truth_images) |
{ |
const auto ignored = [](const truth_instance& truth) { return truth.mmod_rect.ignore; }; |
const auto& truth_instances = truth_image.truth_instances; |
if (!std::all_of(truth_instances.begin(), truth_instances.end(), ignored)) |
result.push_back(truth_image); |
} |
return result; |
} |
int main(int argc, char** argv) try |
{ |
if (argc < 2) |
{ |
cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl; |
cout << endl; |
cout << "You call this program like this: " << endl; |
cout << "./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size] [class-1] [class-2] [class-3] ..." << endl; |
return 1; |
} |
cout << "\nSCANNING PASCAL VOC2012 DATASET\n" << endl; |
const auto listing = get_pascal_voc2012_train_listing(argv[1]); |
cout << "images in entire dataset: " << listing.size() << endl; |
if (listing.size() == 0) |
{ |
cout << "Didn't find the VOC2012 dataset. " << endl; |
return 1; |
} |
const unsigned int det_minibatch_size = argc >= 3 ? std::stoi(argv[2]) : 35; |
const unsigned int seg_minibatch_size = argc >= 4 ? std::stoi(argv[3]) : 100; |
cout << "det mini-batch size: " << det_minibatch_size << endl; |
cout << "seg mini-batch size: " << seg_minibatch_size << endl; |
std::vector<std::string> desired_classlabels; |
for (int arg = 4; arg < argc; ++arg) |
desired_classlabels.push_back(argv[arg]); |
if (desired_classlabels.empty()) |
{ |
desired_classlabels.push_back("bicycle"); |
desired_classlabels.push_back("car"); |
desired_classlabels.push_back("cat"); |
} |
cout << "desired classlabels:"; |
for (const auto& desired_classlabel : desired_classlabels) |
cout << " " << desired_classlabel; |
cout << endl; |
cout << endl << "Extracting all truth instances..."; |
const auto truth_instances = load_all_truth_instances(listing); |
cout << " Done!" << endl << endl; |
DLIB_CASSERT(listing.size() == truth_instances.size()); |
std::vector<truth_image> original_truth_images; |
for (size_t i = 0, end = listing.size(); i < end; ++i) |
{ |
original_truth_images.push_back(truth_image{ |
listing[i], truth_instances[i] |
}); |
} |
auto truth_images_filtered_by_class = filter_based_on_classlabel(original_truth_images, desired_classlabels); |
cout << "images in dataset filtered by class: " << truth_images_filtered_by_class.size() << endl; |
ignore_some_truth_boxes(truth_images_filtered_by_class); |
const auto truth_images = filter_images_with_no_truth(truth_images_filtered_by_class); |
cout << "images in dataset after ignoring some truth boxes: " << truth_images.size() << endl; |
cout << endl << "Training detector network:" << endl; |
const auto det_net = train_detection_network(truth_images, det_minibatch_size); |
std::map<std::string, seg_bnet_type> seg_nets_by_class; |
constexpr bool separate_seg_net_for_each_class = true; |
if (separate_seg_net_for_each_class) |
{ |
for (const auto& classlabel : desired_classlabels) |
{ |
const auto class_images = filter_based_on_classlabel(truth_images, { classlabel }); |
cout << endl << "Training segmentation network for class " << classlabel << ":" << endl; |
seg_nets_by_class[classlabel] = train_segmentation_network(class_images, seg_minibatch_size, classlabel); |
} |
} |
else |
{ |
cout << "Training a single segmentation network:" << endl; |
seg_nets_by_class[""] = train_segmentation_network(truth_images, seg_minibatch_size, ""); |
} |
cout << "Saving networks" << endl; |
serialize(instance_segmentation_net_filename) << det_net << seg_nets_by_class; |
} |
catch(std::exception& e) |
{ |
cout << e.what() << endl; |
} |