// File size: 2,325 Bytes
// Revision: 158b61b
#include "Util.h"
#include "Classifier.h"
#include <boost/algorithm/string/predicate.hpp>
#include <boost/iostreams/device/file.hpp>
using namespace std;
using namespace boost::algorithm;
using namespace Moses;
namespace Discriminative
{
/**
 * Set up the writer that receives the training examples.
 *
 * @param outputFile path handed to the file sink; when it ends in ".gz" a
 *                   gzip compressor is pushed onto the stream chain before
 *                   the file sink.
 */
VWTrainer::VWTrainer(const std::string &outputFile)
{
  // Nothing has been emitted yet: no example, no source (label-independent)
  // feature and no target (label-dependent) feature.
  m_isFirstExample = true;
  m_isFirstTarget = true;
  m_isFirstSource = true;

  const bool wantsGzip = boost::algorithm::ends_with(outputFile, ".gz");
  if (wantsGzip) {
    m_bfos.push(boost::iostreams::gzip_compressor());
  }

  // The file sink is pushed last, after the optional compressor.
  m_bfos.push(boost::iostreams::file_sink(outputFile));
}
/**
 * Write one last newline to the output stream and close it.
 */
VWTrainer::~VWTrainer()
{
  // Presumably this terminates the final example the same way a blank line
  // separates earlier examples -- confirm against the expected file format.
  m_bfos << "\n";

  boost::iostreams::close(m_bfos);
}
/**
 * Buffer one label-independent feature for the current example.
 *
 * On the first such feature since construction or since the last Train()
 * call, a new "shared |s" section is opened. For every example except the
 * first, a blank line is written beforehand to finish the previous example.
 *
 * @param name  feature name (escaped by AddFeature before it is buffered)
 * @param value feature value
 * @return a pair of 0 and @p value; the index is always 0 because features
 *         are not hashed here
 */
FeatureType VWTrainer::AddLabelIndependentFeature(const StringPiece &name, float value)
{
  if (m_isFirstSource) {
    // The very first example has no predecessor to separate from.
    if (! m_isFirstExample) {
      // finish previous example
      m_bfos << "\n";
    }
    m_isFirstExample = false;
    m_isFirstSource = false;

    // Flush whatever is still buffered before opening the shared section.
    const bool hasPending = ! m_outputBuffer.empty();
    if (hasPending) {
      WriteBuffer();
    }

    m_outputBuffer.push_back("shared |s");
  }

  AddFeature(name, value);

  return std::make_pair(0, value); // we don't hash features
}
/**
 * Buffer one label-dependent feature for the current candidate.
 *
 * On the first such feature since construction or since the last Train()
 * call, any pending buffered line is written out and a new "|t" section is
 * started.
 *
 * @param name  feature name (escaped by AddFeature before it is buffered)
 * @param value feature value
 * @return a pair of 0 and @p value; the index is always 0 because features
 *         are not hashed here
 */
FeatureType VWTrainer::AddLabelDependentFeature(const StringPiece &name, float value)
{
  if (m_isFirstTarget) {
    m_isFirstTarget = false;

    // Emit the line collected so far (if any) before starting "|t".
    const bool hasPending = ! m_outputBuffer.empty();
    if (hasPending) {
      WriteBuffer();
    }

    m_outputBuffer.push_back("|t");
  }

  AddFeature(name, value);

  return std::make_pair(0, value); // we don't hash features
}
/**
 * Not supported by this trainer: always throws.
 *
 * @param features ignored
 * @throws std::logic_error unconditionally
 */
void VWTrainer::AddLabelIndependentFeatureVector(const FeatureVector &features)
{
  throw std::logic_error("VW trainer does not support feature IDs.");
}
/**
 * Not supported by this trainer: always throws.
 *
 * @param features ignored
 * @throws std::logic_error unconditionally
 */
void VWTrainer::AddLabelDependentFeatureVector(const FeatureVector &features)
{
  throw std::logic_error("VW trainer does not support feature IDs.");
}
/**
 * Finish the currently buffered line: prepend "<label>:<loss>" to it and
 * write it to the output stream.
 *
 * Both "first feature" flags are re-armed, so the next label-independent or
 * label-dependent feature opens a fresh section.
 *
 * @param label label text placed before the colon
 * @param loss  value formatted with SPrint and placed after the colon
 */
void VWTrainer::Train(const StringPiece &label, float loss)
{
  // Assemble "<label>:<loss>" and put it at the front of the buffered tokens.
  std::string labelField = label.as_string();
  labelField += ":";
  labelField += SPrint(loss);
  m_outputBuffer.push_front(labelField);

  m_isFirstTarget = true;
  m_isFirstSource = true;

  WriteBuffer();
}
/**
 * Prediction is not available on the trainer: always throws.
 *
 * @param label ignored
 * @throws std::logic_error unconditionally
 */
float VWTrainer::Predict(const StringPiece &label)
{
  throw std::logic_error("Trying to predict during training!");
}
/**
 * Append one "<escaped name>:<value>" token to the output buffer.
 *
 * @param name  feature name; passed through EscapeSpecialChars first
 * @param value feature value; formatted with SPrint
 */
void VWTrainer::AddFeature(const StringPiece &name, float value)
{
  std::string token = EscapeSpecialChars(name.as_string());
  token += ":";
  token += SPrint(value);

  m_outputBuffer.push_back(token);
}
/**
 * Write all buffered tokens as one space-separated line, then empty the
 * buffer.
 */
void VWTrainer::WriteBuffer()
{
  m_bfos << Join(" ", m_outputBuffer.begin(), m_outputBuffer.end());
  m_bfos << "\n";

  m_outputBuffer.clear();
}
} // namespace Discriminative
|