|
|
|
|
|
#ifndef DLIB_AnY_TRAINER_H_ |
|
#define DLIB_AnY_TRAINER_H_ |
|
|
|
#include "any.h" |
|
|
|
#include "any_decision_function.h" |
|
|
|
#include "any_trainer_abstract.h" |
|
#include <vector> |
|
|
|
namespace dlib |
|
{ |
|
|
|
|
|
|
|
template < |
|
typename sample_type_, |
|
typename scalar_type_ = double |
|
> |
|
class any_trainer |
|
{ |
|
public: |
|
typedef sample_type_ sample_type; |
|
typedef scalar_type_ scalar_type; |
|
typedef default_memory_manager mem_manager_type; |
|
typedef any_decision_function<sample_type, scalar_type> trained_function_type; |
|
|
|
|
|
any_trainer() |
|
{ |
|
} |
|
|
|
any_trainer ( |
|
const any_trainer& item |
|
) |
|
{ |
|
if (item.data) |
|
{ |
|
item.data->copy_to(data); |
|
} |
|
} |
|
|
|
template <typename T> |
|
any_trainer ( |
|
const T& item |
|
) |
|
{ |
|
typedef typename basic_type<T>::type U; |
|
data.reset(new derived<U>(item)); |
|
} |
|
|
|
void clear ( |
|
) |
|
{ |
|
data.reset(); |
|
} |
|
|
|
template <typename T> |
|
bool contains ( |
|
) const |
|
{ |
|
typedef typename basic_type<T>::type U; |
|
return dynamic_cast<derived<U>*>(data.get()) != 0; |
|
} |
|
|
|
bool is_empty( |
|
) const |
|
{ |
|
return data.get() == 0; |
|
} |
|
|
|
trained_function_type train ( |
|
const std::vector<sample_type>& samples, |
|
const std::vector<scalar_type>& labels |
|
) const |
|
{ |
|
|
|
DLIB_ASSERT(is_empty() == false, |
|
"\t trained_function_type any_trainer::train()" |
|
<< "\n\t You can't call train() on an empty any_trainer" |
|
<< "\n\t this: " << this |
|
); |
|
|
|
return data->train(samples, labels); |
|
} |
|
|
|
template <typename T> |
|
T& cast_to( |
|
) |
|
{ |
|
typedef typename basic_type<T>::type U; |
|
derived<U>* d = dynamic_cast<derived<U>*>(data.get()); |
|
if (d == 0) |
|
{ |
|
throw bad_any_cast(); |
|
} |
|
|
|
return d->item; |
|
} |
|
|
|
template <typename T> |
|
const T& cast_to( |
|
) const |
|
{ |
|
typedef typename basic_type<T>::type U; |
|
derived<U>* d = dynamic_cast<derived<U>*>(data.get()); |
|
if (d == 0) |
|
{ |
|
throw bad_any_cast(); |
|
} |
|
|
|
return d->item; |
|
} |
|
|
|
template <typename T> |
|
T& get( |
|
) |
|
{ |
|
typedef typename basic_type<T>::type U; |
|
derived<U>* d = dynamic_cast<derived<U>*>(data.get()); |
|
if (d == 0) |
|
{ |
|
d = new derived<U>(); |
|
data.reset(d); |
|
} |
|
|
|
return d->item; |
|
} |
|
|
|
any_trainer& operator= ( |
|
const any_trainer& item |
|
) |
|
{ |
|
any_trainer(item).swap(*this); |
|
return *this; |
|
} |
|
|
|
void swap ( |
|
any_trainer& item |
|
) |
|
{ |
|
data.swap(item.data); |
|
} |
|
|
|
private: |
|
|
|
struct base |
|
{ |
|
virtual ~base() {} |
|
|
|
virtual trained_function_type train ( |
|
const std::vector<sample_type>& samples, |
|
const std::vector<scalar_type>& labels |
|
) const = 0; |
|
|
|
virtual void copy_to ( |
|
std::unique_ptr<base>& dest |
|
) const = 0; |
|
}; |
|
|
|
template <typename T> |
|
struct derived : public base |
|
{ |
|
T item; |
|
derived() {} |
|
derived(const T& val) : item(val) {} |
|
|
|
virtual void copy_to ( |
|
std::unique_ptr<base>& dest |
|
) const |
|
{ |
|
dest.reset(new derived<T>(item)); |
|
} |
|
|
|
virtual trained_function_type train ( |
|
const std::vector<sample_type>& samples, |
|
const std::vector<scalar_type>& labels |
|
) const |
|
{ |
|
return item.train(samples, labels); |
|
} |
|
}; |
|
|
|
std::unique_ptr<base> data; |
|
}; |
|
|
|
|
|
|
|
template < |
|
typename sample_type, |
|
typename scalar_type |
|
> |
|
inline void swap ( |
|
any_trainer<sample_type,scalar_type>& a, |
|
any_trainer<sample_type,scalar_type>& b |
|
) { a.swap(b); } |
|
|
|
|
|
|
|
template <typename T, typename U, typename V> |
|
T& any_cast(any_trainer<U,V>& a) { return a.template cast_to<T>(); } |
|
|
|
template <typename T, typename U, typename V> |
|
const T& any_cast(const any_trainer<U,V>& a) { return a.template cast_to<T>(); } |
|
|
|
|
|
|
|
} |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|