// AshanGimhana's picture
// Upload folder using huggingface_hub
// 9375c9a verified
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_AnY_TRAINER_H_
#define DLIB_AnY_TRAINER_H_
#include "any.h"
#include "any_decision_function.h"
#include "any_trainer_abstract.h"
#include <memory>
#include <utility>
#include <vector>
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename sample_type_,
typename scalar_type_ = double
>
class any_trainer
{
public:
typedef sample_type_ sample_type;
typedef scalar_type_ scalar_type;
typedef default_memory_manager mem_manager_type;
typedef any_decision_function<sample_type, scalar_type> trained_function_type;
any_trainer()
{
}
any_trainer (
const any_trainer& item
)
{
if (item.data)
{
item.data->copy_to(data);
}
}
template <typename T>
any_trainer (
const T& item
)
{
typedef typename basic_type<T>::type U;
data.reset(new derived<U>(item));
}
void clear (
)
{
data.reset();
}
template <typename T>
bool contains (
) const
{
typedef typename basic_type<T>::type U;
return dynamic_cast<derived<U>*>(data.get()) != 0;
}
bool is_empty(
) const
{
return data.get() == 0;
}
trained_function_type train (
const std::vector<sample_type>& samples,
const std::vector<scalar_type>& labels
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_empty() == false,
"\t trained_function_type any_trainer::train()"
<< "\n\t You can't call train() on an empty any_trainer"
<< "\n\t this: " << this
);
return data->train(samples, labels);
}
template <typename T>
T& cast_to(
)
{
typedef typename basic_type<T>::type U;
derived<U>* d = dynamic_cast<derived<U>*>(data.get());
if (d == 0)
{
throw bad_any_cast();
}
return d->item;
}
template <typename T>
const T& cast_to(
) const
{
typedef typename basic_type<T>::type U;
derived<U>* d = dynamic_cast<derived<U>*>(data.get());
if (d == 0)
{
throw bad_any_cast();
}
return d->item;
}
template <typename T>
T& get(
)
{
typedef typename basic_type<T>::type U;
derived<U>* d = dynamic_cast<derived<U>*>(data.get());
if (d == 0)
{
d = new derived<U>();
data.reset(d);
}
return d->item;
}
any_trainer& operator= (
const any_trainer& item
)
{
any_trainer(item).swap(*this);
return *this;
}
void swap (
any_trainer& item
)
{
data.swap(item.data);
}
private:
struct base
{
virtual ~base() {}
virtual trained_function_type train (
const std::vector<sample_type>& samples,
const std::vector<scalar_type>& labels
) const = 0;
virtual void copy_to (
std::unique_ptr<base>& dest
) const = 0;
};
template <typename T>
struct derived : public base
{
T item;
derived() {}
derived(const T& val) : item(val) {}
virtual void copy_to (
std::unique_ptr<base>& dest
) const
{
dest.reset(new derived<T>(item));
}
virtual trained_function_type train (
const std::vector<sample_type>& samples,
const std::vector<scalar_type>& labels
) const
{
return item.train(samples, labels);
}
};
std::unique_ptr<base> data;
};
// ----------------------------------------------------------------------------------------
/*!
    Non-member swap for any_trainer. Because it lives in the same namespace as
    any_trainer, unqualified swap(a, b) calls find it through argument-dependent
    lookup. It forwards to the member swap(), which exchanges the held trainers
    without copying them.
!*/
template <typename sample_type, typename scalar_type>
inline void swap (
    any_trainer<sample_type,scalar_type>& a,
    any_trainer<sample_type,scalar_type>& b
)
{
    a.swap(b);
}
// ----------------------------------------------------------------------------------------
/*!
    any_cast<T>(a) returns a reference to the T held inside a by forwarding to
    a.cast_to<T>(). cast_to throws bad_any_cast when a does not hold a T.
    U and V are deduced from a's sample and scalar types.
!*/
template <typename T, typename U, typename V>
T& any_cast(
    any_trainer<U,V>& a
)
{
    return a.template cast_to<T>();
}

// Read-only counterpart of the overload above.
template <typename T, typename U, typename V>
const T& any_cast(
    const any_trainer<U,V>& a
)
{
    return a.template cast_to<T>();
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_AnY_TRAINER_H_