File size: 25,499 Bytes
9375c9a |
|
<html><!-- Created using the cpp_pretty_printer from the dlib C++ library. See http://dlib.net for updates. --><head><title>dlib C++ Library - custom_trainer_ex.cpp</title></head><body bgcolor='white'><pre>
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
</font><font color='#009900'>/*
This example program shows you how to create your own custom binary classification
trainer object and use it with the multiclass classification tools in the dlib C++
library. This example assumes you have already become familiar with the concepts
introduced in the <a href="multiclass_classification_ex.cpp.html">multiclass_classification_ex.cpp</a> example program.
In this example we will create a very simple trainer object that takes a binary
classification problem and produces a decision rule which says a test point has the
same class as whichever centroid it is closest to.
The multiclass training dataset will consist of four classes. Each class will be a blob
of points in one of the quadrants of the cartesian plane. For fun, we will use
std::string labels and therefore the labels of these classes will be the following:
"upper_left",
"upper_right",
"lower_left",
"lower_right"
*/</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>svm_threaded.h<font color='#5555FF'>></font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>vector<font color='#5555FF'>></font>
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>rand.h<font color='#5555FF'>></font>
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std;
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib;
<font color='#009900'>// Our data will be 2-dimensional, so declare an appropriate type (a 2x1 matrix of doubles) to contain these points.
</font><font color='#0000FF'>typedef</font> matrix<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>2</font>,<font color='#979000'>1</font><font color='#5555FF'>></font> sample_type;
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'>struct</font> <b><a name='custom_decision_function'></a>custom_decision_function</b>
<b>{</b>
<font color='#009900'>/*!
WHAT THIS OBJECT REPRESENTS
This object is the representation of our binary decision rule. It stores one
center per class and labels a point according to whichever center is nearer.
!*/</font>
<font color='#009900'>// centers of the two classes (the +1 class and the -1 class)
</font> sample_type positive_center, negative_center;
<font color='#0000FF'><u>double</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> sample_type<font color='#5555FF'>&</font> x
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<b>{</b>
<font color='#009900'>// If x is strictly closer to the positive center than to the negative center then
// return +1. Otherwise, including an exact tie in distance, return -1.
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>length</font><font face='Lucida Console'>(</font>positive_center <font color='#5555FF'>-</font> x<font face='Lucida Console'>)</font> <font color='#5555FF'><</font> <font color='#BB00BB'>length</font><font face='Lucida Console'>(</font>negative_center <font color='#5555FF'>-</font> x<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
<font color='#0000FF'>return</font> <font color='#5555FF'>+</font><font color='#979000'>1</font>;
<font color='#0000FF'>else</font>
<font color='#0000FF'>return</font> <font color='#5555FF'>-</font><font color='#979000'>1</font>;
<b>}</b>
<b>}</b>;
<font color='#009900'>// Later on in this example we will save our decision functions to disk. This
</font><font color='#009900'>// pair of routines (serialize and deserialize) is needed for this functionality.
</font><font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> custom_decision_function<font color='#5555FF'>&</font> item, std::ostream<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// Write the state of item to the output stream: the positive center first, then the
// negative center. deserialize() below reads them back in this same order.
</font> <font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.positive_center, out<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>item.negative_center, out<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b> <font face='Lucida Console'>(</font>custom_decision_function<font color='#5555FF'>&</font> item, std::istream<font color='#5555FF'>&</font> in<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#009900'>// Read the data from the input stream and store it in item. The read order (positive
// center, then negative center) matches the order used by serialize() above.
</font> <font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.positive_center, in<font face='Lucida Console'>)</font>;
<font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>item.negative_center, in<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'>class</font> <b><a name='simple_custom_trainer'></a>simple_custom_trainer</b>
<b>{</b>
<font color='#009900'>/*!
WHAT THIS OBJECT REPRESENTS
This is our example custom binary classifier trainer object. It simply
computes the means of the +1 and -1 classes, puts them into our
custom_decision_function, and returns the result.
Below we define the train() function. I have also included the
requires/ensures definition for a generic binary classifier's train().
!*/</font>
<font color='#0000FF'>public</font>:
custom_decision_function <b><a name='train'></a>train</b> <font face='Lucida Console'>(</font>
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>sample_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> samples,
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font><font color='#5555FF'>&</font> labels
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
<font color='#009900'>/*!
requires
- is_binary_classification_problem(samples, labels) == true
(e.g. labels consist of only +1 and -1 values, samples.size() == labels.size())
ensures
- returns a decision function F with the following properties:
- if (new_x is a sample predicted to have a +1 label) then
- F(new_x) >= 0
- else
- F(new_x) < 0
!*/</font>
<b>{</b>
sample_type positive_center, negative_center;
<font color='#009900'>// compute sums of each class (both accumulators start at zero)
</font> positive_center <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
negative_center <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>labels[i] <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>
positive_center <font color='#5555FF'>+</font><font color='#5555FF'>=</font> samples[i];
<font color='#0000FF'>else</font> <font color='#009900'>// this is a -1 sample (any label other than +1 is summed here)
</font> negative_center <font color='#5555FF'>+</font><font color='#5555FF'>=</font> samples[i];
<b>}</b>
<font color='#009900'>// Divide by number of +1 samples to turn the sum into a mean.
// NOTE(review): the divisor is zero if no +1 samples are present. Presumably the
// requires clause above rules that case out. Confirm against the dlib documentation.
</font> positive_center <font color='#5555FF'>/</font><font color='#5555FF'>=</font> <font color='#BB00BB'>sum</font><font face='Lucida Console'>(</font><font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font>labels<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// divide by number of -1 samples (same caveat as for the +1 class)
</font> negative_center <font color='#5555FF'>/</font><font color='#5555FF'>=</font> <font color='#BB00BB'>sum</font><font face='Lucida Console'>(</font><font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font>labels<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
custom_decision_function df;
df.positive_center <font color='#5555FF'>=</font> positive_center;
df.negative_center <font color='#5555FF'>=</font> negative_center;
<font color='#0000FF'>return</font> df;
<b>}</b>
<b>}</b>;
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'><u>void</u></font> <b><a name='generate_data'></a>generate_data</b> <font face='Lucida Console'>(</font>
std::vector<font color='#5555FF'><</font>sample_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> samples,
std::vector<font color='#5555FF'><</font>string<font color='#5555FF'>></font><font color='#5555FF'>&</font> labels
<font face='Lucida Console'>)</font>;
<font color='#009900'>/*!
ensures
- appends the four class dataset described at the top of this file to samples and
labels (one blob of points per quadrant, labeled with the name of that quadrant).
- each class will have 50 samples in it
!*/</font>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
<b>{</b>
std::vector<font color='#5555FF'><</font>sample_type<font color='#5555FF'>></font> samples;
std::vector<font color='#5555FF'><</font>string<font color='#5555FF'>></font> labels;
<font color='#009900'>// First, get our labeled set of training data
</font> <font color='#BB00BB'>generate_data</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>samples.size(): </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<font color='#009900'>// Define the trainer we will use. The second template argument specifies the type
</font> <font color='#009900'>// of label used, which is string in this case.
</font> <font color='#0000FF'>typedef</font> one_vs_one_trainer<font color='#5555FF'><</font>any_trainer<font color='#5555FF'><</font>sample_type<font color='#5555FF'>></font>, string<font color='#5555FF'>></font> ovo_trainer;
ovo_trainer trainer;
<font color='#009900'>// Now tell the one_vs_one_trainer that, by default, it should use the simple_custom_trainer
</font> <font color='#009900'>// to solve the individual binary classification subproblems.
</font> trainer.<font color='#BB00BB'>set_trainer</font><font face='Lucida Console'>(</font><font color='#BB00BB'>simple_custom_trainer</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#009900'>// Next, to make things a little more interesting, we will set up the one_vs_one_trainer
</font> <font color='#009900'>// to use kernel ridge regression to solve the upper_left vs lower_right binary classification
</font> <font color='#009900'>// subproblem.
</font> <font color='#0000FF'>typedef</font> radial_basis_kernel<font color='#5555FF'><</font>sample_type<font color='#5555FF'>></font> rbf_kernel;
krr_trainer<font color='#5555FF'><</font>rbf_kernel<font color='#5555FF'>></font> rbf_trainer;
rbf_trainer.<font color='#BB00BB'>set_kernel</font><font face='Lucida Console'>(</font><font color='#BB00BB'>rbf_kernel</font><font face='Lucida Console'>(</font><font color='#979000'>0.1</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
trainer.<font color='#BB00BB'>set_trainer</font><font face='Lucida Console'>(</font>rbf_trainer, "<font color='#CC0000'>upper_left</font>", "<font color='#CC0000'>lower_right</font>"<font face='Lucida Console'>)</font>;
<font color='#009900'>// Now let's do 5-fold cross-validation using the one_vs_one_trainer we just set up.
</font> <font color='#009900'>// As an aside, always shuffle the order of the samples before doing cross validation.
</font> <font color='#009900'>// For a discussion of why this is a good idea see the <a href="svm_ex.cpp.html">svm_ex.cpp</a> example.
</font> <font color='#BB00BB'>randomize_samples</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>cross validation: \n</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>cross_validate_multiclass_trainer</font><font face='Lucida Console'>(</font>trainer, samples, labels, <font color='#979000'>5</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<font color='#009900'>// This dataset is very easy and everything is correctly classified. Therefore, the output of
</font> <font color='#009900'>// cross validation is the following confusion matrix.
</font> <font color='#009900'>/*
50 0 0 0
0 50 0 0
0 0 50 0
0 0 0 50
*/</font>
<font color='#009900'>// We can also obtain the decision rule, as always, by calling train() on the whole dataset.
</font> one_vs_one_decision_function<font color='#5555FF'><</font>ovo_trainer<font color='#5555FF'>></font> df <font color='#5555FF'>=</font> trainer.<font color='#BB00BB'>train</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>predicted label: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>df</font><font face='Lucida Console'>(</font>samples[<font color='#979000'>0</font>]<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>, true label: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> labels[<font color='#979000'>0</font>] <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>predicted label: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>df</font><font face='Lucida Console'>(</font>samples[<font color='#979000'>90</font>]<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>, true label: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> labels[<font color='#979000'>90</font>] <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<font color='#009900'>// The output is:
</font> <font color='#009900'>/*
predicted label: upper_right, true label: upper_right
predicted label: lower_left, true label: lower_left
*/</font>
<font color='#009900'>// Finally, let's save our multiclass decision rule to disk. Remember that we have
</font> <font color='#009900'>// to specify the types of binary decision function used inside the one_vs_one_decision_function.
</font> one_vs_one_decision_function<font color='#5555FF'><</font>ovo_trainer,
custom_decision_function, <font color='#009900'>// This is the output of the simple_custom_trainer
</font> decision_function<font color='#5555FF'><</font>radial_basis_kernel<font color='#5555FF'><</font>sample_type<font color='#5555FF'>></font> <font color='#5555FF'>></font> <font color='#009900'>// This is the output of the rbf_trainer
</font> <font color='#5555FF'>></font> df2, df3;
df2 <font color='#5555FF'>=</font> df;
<font color='#009900'>// save to a file called df.dat
</font> <font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>df.dat</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> df2;
<font color='#009900'>// load the function back in from disk and store it in df3.
</font> <font color='#BB00BB'>deserialize</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>df.dat</font>"<font face='Lucida Console'>)</font> <font color='#5555FF'>></font><font color='#5555FF'>></font> df3;
<font color='#009900'>// Test df3 to see that this worked.
</font> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>predicted label: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>df3</font><font face='Lucida Console'>(</font>samples[<font color='#979000'>0</font>]<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>, true label: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> labels[<font color='#979000'>0</font>] <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>predicted label: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>df3</font><font face='Lucida Console'>(</font>samples[<font color='#979000'>90</font>]<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>, true label: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> labels[<font color='#979000'>90</font>] <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<font color='#009900'>// Test df3 on the samples and labels and print the confusion matrix.
</font> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>test deserialized function: \n</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>test_multiclass_decision_function</font><font face='Lucida Console'>(</font>df3, samples, labels<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl;
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'><u>void</u></font> <b><a name='generate_data'></a>generate_data</b> <font face='Lucida Console'>(</font>
std::vector<font color='#5555FF'><</font>sample_type<font color='#5555FF'>></font><font color='#5555FF'>&</font> samples,
std::vector<font color='#5555FF'><</font>string<font color='#5555FF'>></font><font color='#5555FF'>&</font> labels
<font face='Lucida Console'>)</font>
<b>{</b>
<font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> num <font color='#5555FF'>=</font> <font color='#979000'>50</font>;
sample_type m;
dlib::rand rnd;
<font color='#009900'>// Each loop below appends num samples and num matching labels. Nothing already in
// samples or labels is cleared. Every sample is the blob offset m plus a random 2x1
// displacement produced by randm(2,1,rnd).
// add some points in the upper right quadrant, offset by m = (10, 10)
</font> m <font color='#5555FF'>=</font> <font color='#979000'>10</font>, <font color='#979000'>10</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> num; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
<b>{</b>
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>m <font color='#5555FF'>+</font> <font color='#BB00BB'>randm</font><font face='Lucida Console'>(</font><font color='#979000'>2</font>,<font color='#979000'>1</font>,rnd<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>upper_right</font>"<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// add some points in the upper left quadrant, offset by m = (-10, 10)
</font> m <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>10</font>, <font color='#979000'>10</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> num; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
<b>{</b>
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>m <font color='#5555FF'>+</font> <font color='#BB00BB'>randm</font><font face='Lucida Console'>(</font><font color='#979000'>2</font>,<font color='#979000'>1</font>,rnd<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>upper_left</font>"<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// add some points in the lower right quadrant, offset by m = (10, -10)
</font> m <font color='#5555FF'>=</font> <font color='#979000'>10</font>, <font color='#5555FF'>-</font><font color='#979000'>10</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> num; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
<b>{</b>
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>m <font color='#5555FF'>+</font> <font color='#BB00BB'>randm</font><font face='Lucida Console'>(</font><font color='#979000'>2</font>,<font color='#979000'>1</font>,rnd<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>lower_right</font>"<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// add some points in the lower left quadrant, offset by m = (-10, -10)
</font> m <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>10</font>, <font color='#5555FF'>-</font><font color='#979000'>10</font>;
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> num; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
<b>{</b>
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>m <font color='#5555FF'>+</font> <font color='#BB00BB'>randm</font><font face='Lucida Console'>(</font><font color='#979000'>2</font>,<font color='#979000'>1</font>,rnd<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>lower_left</font>"<font face='Lucida Console'>)</font>;
<b>}</b>
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
</pre></body></html> |