// Copyright (C) 2018  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_AUTO_LEARnING_CPP_
#define DLIB_AUTO_LEARnING_CPP_

#include "auto.h"
#include "../global_optimization.h"
#include "svm_c_trainer.h"
#include <iostream>
#include <iomanip>
#include <mutex>
#include <thread>

namespace dlib
{

    normalized_function<decision_function<radial_basis_kernel<matrix<double,0,1>>>> auto_train_rbf_classifier (
        std::vector<matrix<double,0,1>> x,
        std::vector<double> y,
        const std::chrono::nanoseconds max_runtime,
        bool be_verbose
    )
    {
        const auto num_positive_training_samples = sum(mat(y)>0);
        const auto num_negative_training_samples = sum(mat(y)<0);
        DLIB_CASSERT(num_positive_training_samples >= 6 && num_negative_training_samples >= 6,
            "You must provide at least 6 examples of each class to this training routine.");

        // make sure requires clause is not broken
        DLIB_CASSERT(is_binary_classification_problem(x,y) == true,
            "\t auto_train_rbf_classifier(x,y)"
            << "\n\t invalid inputs were given to this function"
            << "\n\t x.size(): " << x.size()
            << "\n\t y.size(): " << y.size()
            << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
        );

        randomize_samples(x,y);

        vector_normalizer<matrix<double,0,1>> normalizer;
        // let the normalizer learn the mean and standard deviation of the samples
        normalizer.train(x);
        // now normalize each sample
        for (auto& samp : x)
            samp = normalizer(samp);

        normalized_function<decision_function<radial_basis_kernel<matrix<double,0,1>>>> df;
        df.normalizer = normalizer;

        typedef radial_basis_kernel<matrix<double,0,1>> kernel_type;

        // The mutex serializes the verbose printing, since find_max_global() will call
        // this scoring lambda from multiple threads concurrently.
        std::mutex m;
        auto cross_validation_score = [&](const double gamma, const double c1, const double c2)
        {
            svm_c_trainer<kernel_type> trainer;
            trainer.set_kernel(kernel_type(gamma));
            trainer.set_c_class1(c1);
            trainer.set_c_class2(c2);

            // Perform 6-fold cross validation, then print and return the results.
            matrix<double> result = cross_validate_trainer(trainer, x, y, 6);
            if (be_verbose)
            {
                std::lock_guard<std::mutex> lock(m);
                std::cout << "gamma: " << std::setw(11) << gamma
                          << "  c1: " << std::setw(11) << c1
                          << "  c2: " << std::setw(11) << c2
                          << "  cross validation accuracy: " << result << std::flush;
            }

            // Return the harmonic mean of the two per-class accuracies, minus a small
            // penalty for picking large parameter settings since those are, a priori,
            // less likely to generalize.
            return 2*prod(result)/sum(result) - std::max(c1,c2)/1e12 - gamma/1e8;
        };

        if (be_verbose)
            std::cout << "Searching for best RBF-SVM training parameters..." << std::endl;
        auto result = find_max_global(
            default_thread_pool(),
            cross_validation_score,
            {1e-5, 1e-5, 1e-5},  // lower bound constraints on gamma, c1, and c2, respectively
            {100,  1e6,  1e6},   // upper bound constraints on gamma, c1, and c2, respectively
            max_runtime);

        double best_gamma = result.x(0);
        double best_c1    = result.x(1);
        double best_c2    = result.x(2);

        if (be_verbose)
        {
            std::cout << " best cross-validation score: " << result.y << std::endl;
            std::cout << " best gamma: " << best_gamma << "   best c1: " << best_c1 << "   best c2: " << best_c2 << std::endl;
        }

        svm_c_trainer<kernel_type> trainer;
        trainer.set_kernel(kernel_type(best_gamma));
        trainer.set_c_class1(best_c1);
        trainer.set_c_class2(best_c2);

        if (be_verbose)
            std::cout << "Training final classifier with best parameters..." << std::endl;

        df.function = trainer.train(x,y);

        return df;
    }

}

#endif // DLIB_AUTO_LEARnING_CPP_
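
// ----------------------------------------------------------------------------------------
// Example usage: a minimal sketch, not part of this file's interface.  It assumes the
// dlib/svm.h umbrella header pulls in the auto_train_rbf_classifier() declaration, and
// the data-loading step is left as a hypothetical stub.
//
//     #include <dlib/svm.h>
//     #include <chrono>
//     #include <iostream>
//
//     int main()
//     {
//         std::vector<dlib::matrix<double,0,1>> x;  // feature vectors
//         std::vector<double> y;                    // labels, +1 or -1
//         // ... fill x and y with at least 6 examples of each class ...
//
//         // Spend up to 5 minutes searching for good gamma/C parameters (printing
//         // progress along the way), then train the final classifier on all the data.
//         auto df = dlib::auto_train_rbf_classifier(x, y, std::chrono::minutes(5), true);
//
//         // df normalizes its input and applies the learned decision function;
//         // the sign of the output is the predicted class.
//         std::cout << "prediction: " << df(x[0]) << std::endl;
//     }
//
// Note that std::chrono::minutes converts implicitly to the std::chrono::nanoseconds
// parameter, so any std::chrono duration works as the runtime budget.
// ----------------------------------------------------------------------------------------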