File size: 1,323 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
#include "Classifier.h"
#include "vw.h"
#include "../moses/Util.h"
#include <iostream>
#include <boost/algorithm/string/predicate.hpp>
using namespace boost::algorithm;
namespace Discriminative
{

/**
 * Prediction-mode constructor.
 *
 * Loads an existing VW model once and keeps the resulting VW instance; every
 * VWPredictor handed out by operator() shares this instance.
 *
 * @param modelFile  path of the VW model, passed to VW via "-i".
 * @param vwOptions  additional VW command-line options. They are appended
 *                   after the model file here and also forwarded to each
 *                   VWPredictor created by operator().
 */
ClassifierFactory::ClassifierFactory(const std::string &modelFile, const std::string &vwOptions)
  : m_vwOptions(vwOptions), m_train(false)
{
  // An explicit separator is required between the model path and the user
  // options: without it, options that do not begin with whitespace
  // (e.g. "--quiet") are glued onto the path ("-i model.vw--quiet") and VW
  // tries to open a non-existent model file. VW tokenizes this option string
  // on whitespace, so the extra space is harmless when vwOptions already
  // starts with one.
  m_VWInstance = VW::initialize(VW_DEFAULT_OPTIONS + " -i " + modelFile + " " + vwOptions);
}

/**
 * Training-mode constructor.
 *
 * No VW instance is created here. Instead, each call to operator() creates a
 * VWTrainer that writes to its own file named "<prefix>.<id>". If the given
 * prefix ended in ".gz", that suffix is moved to the end so the file name is
 * "<prefix>.<id>.gz".
 *
 * @param modelFilePrefix  base name for the per-trainer output files.
 */
ClassifierFactory::ClassifierFactory(const std::string &modelFilePrefix)
  : m_lastId(0), m_train(true)
{
  // Strip a trailing ".gz" so operator() can insert the per-trainer id
  // before it and re-append the suffix afterwards.
  const bool gzip = ends_with(modelFilePrefix, ".gz");
  m_gzip = gzip;
  if (gzip) {
    m_modelFilePrefix = modelFilePrefix.substr(0, modelFilePrefix.size() - 3);
  } else {
    m_modelFilePrefix = modelFilePrefix;
  }
}

/**
 * Finishes the shared VW instance in prediction mode.
 *
 * The training-mode constructor never assigns m_VWInstance, so there is
 * nothing to release in that mode.
 *
 * NOTE(review): this releases a handle owned by the factory; copying a
 * ClassifierFactory would finish the same instance twice. Confirm copying is
 * disabled in Classifier.h.
 */
ClassifierFactory::~ClassifierFactory()
{
  if (! m_train)
    VW::finish(*m_VWInstance);
}

/**
 * Creates a new classifier.
 *
 * Training mode: returns a VWTrainer writing to "<prefix>.<id>" (plus ".gz"
 * when the original prefix had it). The id is a counter that is incremented
 * on every call.
 *
 * Prediction mode: returns a VWPredictor bound to the shared VW instance,
 * configured with the default parser options plus the user's vwOptions.
 */
ClassifierFactory::ClassifierPtr ClassifierFactory::operator()()
{
  if (m_train) {
    // The mutex serializes the read-and-increment of m_lastId, so concurrent
    // callers always receive distinct output files.
    boost::unique_lock<boost::mutex> lock(m_mutex);
    return ClassifierFactory::ClassifierPtr(
             new VWTrainer(m_modelFilePrefix + "." + Moses::SPrint(m_lastId++) + (m_gzip ? ".gz" : "")));
  } else {
    // Same separator rationale as in the prediction constructor: do not rely
    // on the default options string ending in whitespace.
    return ClassifierFactory::ClassifierPtr(
             new VWPredictor(m_VWInstance, VW_DEFAULT_PARSER_OPTIONS + " " + m_vwOptions));
  }
}

}
|