File size: 1,323 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include "Classifier.h"
#include "vw.h"
#include "../moses/Util.h"
#include <iostream>
#include <boost/algorithm/string/predicate.hpp>

using namespace boost::algorithm;

namespace Discriminative
{

// Prediction-mode factory: loads the VW model from `modelFile` once, and every
// classifier produced by operator() shares this single VW instance.
//
// @param modelFile  path of the trained VW model (passed to VW via "-i").
// @param vwOptions  extra VW command-line options appended after the defaults;
//                   stored so each VWPredictor can receive them as well.
ClassifierFactory::ClassifierFactory(const std::string &modelFile, const std::string &vwOptions)
  : m_vwOptions(vwOptions), m_train(false)
{
  // These members only matter in training mode; give them defined values anyway so
  // the object never carries indeterminate state. (Assigned here rather than in the
  // mem-initializer list to avoid depending on the header's declaration order.)
  m_lastId = 0;
  m_gzip = false;

  // Separate the model path from the user options with an explicit space. Without it,
  // an options string that does not start with whitespace would be concatenated onto
  // the file name (e.g. "model.vw--ring_size 1024"). Extra whitespace is harmless to
  // VW's option parser.
  m_VWInstance = VW::initialize(VW_DEFAULT_OPTIONS + " -i " + modelFile + " " + vwOptions);
}

// Training-mode factory: no VW instance is created here; instead operator() hands out
// a fresh VWTrainer per call, each writing to "<prefix>.<id>[.gz]".
//
// @param modelFilePrefix  output path prefix for the trainers' files. A trailing ".gz"
//                         is stripped and remembered in m_gzip so it can be re-appended
//                         after the per-trainer numeric suffix.
ClassifierFactory::ClassifierFactory(const std::string &modelFilePrefix)
  : m_lastId(0), m_train(true)
{
  const std::string gzExtension(".gz");
  m_gzip = ends_with(modelFilePrefix, gzExtension);
  m_modelFilePrefix = m_gzip
                      ? modelFilePrefix.substr(0, modelFilePrefix.size() - gzExtension.size())
                      : modelFilePrefix;
}

// Only the prediction-mode constructor initializes m_VWInstance, so VW::finish is
// called on it only in that mode; a training-mode factory has nothing to release here.
ClassifierFactory::~ClassifierFactory()
{
  if (m_train)
    return;

  VW::finish(*m_VWInstance);
}

// Creates a new classifier.
//
// Prediction mode: returns a VWPredictor wrapping the shared VW instance loaded by the
// constructor, configured with the default parser options plus the user's options.
// Training mode: returns a VWTrainer writing to "<prefix>.<id>[.gz]", where <id> is a
// per-factory counter incremented on every call.
ClassifierFactory::ClassifierPtr ClassifierFactory::operator()()
{
  if (! m_train) {
    return ClassifierPtr(
             new VWPredictor(m_VWInstance, VW_DEFAULT_PARSER_OPTIONS + m_vwOptions));
  }

  // The counter is read and incremented under the mutex so concurrent callers never
  // receive the same file id; the trainer is also constructed while the lock is held.
  boost::unique_lock<boost::mutex> guard(m_mutex);
  const std::string trainerFile =
    m_modelFilePrefix + "." + Moses::SPrint(m_lastId++) + (m_gzip ? ".gz" : "");
  return ClassifierPtr(new VWTrainer(trainerFile));
}

}