|
#ifndef moses_Classifier_h |
|
#define moses_Classifier_h |
|
|
|
#include <stdint.h>

#include <deque>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <boost/shared_ptr.hpp>
#include <boost/noncopyable.hpp>
#include <boost/thread/condition_variable.hpp>
#include <boost/thread/locks.hpp>
#include <boost/thread/mutex.hpp>
#include <boost/iostreams/filtering_stream.hpp>
#include <boost/iostreams/filter/gzip.hpp>

#include "../util/string_piece.hh"
#include "../moses/Util.h"
|
|
|
|
|
// Global-scope forward declarations. This header only ever refers to these
// types through pointers (VWPredictor and ClassifierFactory members), so the
// full definitions are not needed here.
// NOTE(review): presumably both come from the Vowpal Wabbit library -- confirm.
struct vw;
class ezexample;
|
|
|
namespace Discriminative |
|
{ |
|
/// A single feature as handed back by the Add*Feature calls: a 32-bit unsigned
/// integer paired with a float.
/// NOTE(review): presumably (feature ID, feature value) -- confirm against the
/// classifier implementations.
typedef std::pair<uint32_t, float> FeatureType;

/// A batch of features, accepted by the Add*FeatureVector calls.
typedef std::vector<FeatureType> FeatureVector;
|
|
|
|
|
|
|
|
|
class Classifier |
|
{ |
|
public: |
|
|
|
|
|
|
|
virtual FeatureType AddLabelIndependentFeature(const StringPiece &name, float value) = 0; |
|
|
|
|
|
|
|
|
|
virtual FeatureType AddLabelDependentFeature(const StringPiece &name, float value) = 0; |
|
|
|
|
|
|
|
|
|
virtual void AddLabelIndependentFeatureVector(const FeatureVector &features) = 0; |
|
|
|
|
|
|
|
|
|
virtual void AddLabelDependentFeatureVector(const FeatureVector &features) = 0; |
|
|
|
|
|
|
|
|
|
|
|
virtual void Train(const StringPiece &label, float loss) = 0; |
|
|
|
|
|
|
|
|
|
|
|
virtual float Predict(const StringPiece &label) = 0; |
|
|
|
|
|
FeatureType AddLabelIndependentFeature(const StringPiece &name) { |
|
return AddLabelIndependentFeature(name, 1.0); |
|
} |
|
|
|
FeatureType AddLabelDependentFeature(const StringPiece &name) { |
|
return AddLabelDependentFeature(name, 1.0); |
|
} |
|
|
|
virtual ~Classifier() {} |
|
|
|
protected: |
|
|
|
|
|
|
|
static std::string EscapeSpecialChars(const std::string &str) { |
|
std::string out; |
|
out = Moses::Replace(str, "\\", "_/_"); |
|
out = Moses::Replace(out, "|", "\\/"); |
|
out = Moses::Replace(out, ":", "\\;"); |
|
out = Moses::Replace(out, " ", "\\_"); |
|
return out; |
|
} |
|
|
|
const static bool DEBUG = false; |
|
|
|
}; |
|
|
|
|
|
|
|
// Default option strings. Both begin and end with a space.
// NOTE(review): presumably these are Vowpal Wabbit command-line options, padded
// so they can be concatenated with further options -- confirm at the use sites.
const std::string VW_DEFAULT_OPTIONS(" --hash all --noconstant -q st -t --ldf_override sc ");
const std::string VW_DEFAULT_PARSER_OPTIONS(" --quiet --hash all --noconstant -q st -t --csoaa_ldf sc ");
|
|
|
|
|
|
|
|
|
class VWTrainer : public Classifier |
|
{ |
|
public: |
|
VWTrainer(const std::string &outputFile); |
|
virtual ~VWTrainer(); |
|
|
|
virtual FeatureType AddLabelIndependentFeature(const StringPiece &name, float value); |
|
virtual FeatureType AddLabelDependentFeature(const StringPiece &name, float value); |
|
virtual void AddLabelIndependentFeatureVector(const FeatureVector &features); |
|
virtual void AddLabelDependentFeatureVector(const FeatureVector &features); |
|
virtual void Train(const StringPiece &label, float loss); |
|
virtual float Predict(const StringPiece &label); |
|
|
|
protected: |
|
void AddFeature(const StringPiece &name, float value); |
|
|
|
bool m_isFirstSource, m_isFirstTarget, m_isFirstExample; |
|
|
|
private: |
|
boost::iostreams::filtering_ostream m_bfos; |
|
std::deque<std::string> m_outputBuffer; |
|
|
|
void WriteBuffer(); |
|
}; |
|
|
|
|
|
|
|
|
|
class VWPredictor : public Classifier, private boost::noncopyable |
|
{ |
|
public: |
|
VWPredictor(const std::string &modelFile, const std::string &vwOptions); |
|
virtual ~VWPredictor(); |
|
|
|
virtual FeatureType AddLabelIndependentFeature(const StringPiece &name, float value); |
|
virtual FeatureType AddLabelDependentFeature(const StringPiece &name, float value); |
|
virtual void AddLabelIndependentFeatureVector(const FeatureVector &features); |
|
virtual void AddLabelDependentFeatureVector(const FeatureVector &features); |
|
virtual void Train(const StringPiece &label, float loss); |
|
virtual float Predict(const StringPiece &label); |
|
|
|
friend class ClassifierFactory; |
|
|
|
protected: |
|
FeatureType AddFeature(const StringPiece &name, float values); |
|
|
|
::vw *m_VWInstance, *m_VWParser; |
|
::ezexample *m_ex; |
|
|
|
|
|
bool m_sharedVwInstance; |
|
bool m_isFirstSource, m_isFirstTarget; |
|
|
|
private: |
|
|
|
VWPredictor(vw * instance, const std::string &vwOption); |
|
}; |
|
|
|
|
|
|
|
|
|
class ClassifierFactory : private boost::noncopyable |
|
{ |
|
public: |
|
typedef boost::shared_ptr<Classifier> ClassifierPtr; |
|
|
|
|
|
|
|
|
|
ClassifierFactory(const std::string &modelFile, const std::string &vwOptions); |
|
|
|
|
|
|
|
|
|
ClassifierFactory(const std::string &modelFilePrefix); |
|
|
|
|
|
ClassifierPtr operator()(); |
|
|
|
~ClassifierFactory(); |
|
|
|
private: |
|
std::string m_vwOptions; |
|
::vw *m_VWInstance; |
|
int m_lastId; |
|
std::string m_modelFilePrefix; |
|
bool m_gzip; |
|
boost::mutex m_mutex; |
|
const bool m_train; |
|
}; |
|
|
|
} |
|
|
|
#endif |
|
|