File size: 7,782 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
#pragma once

#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

#include <boost/functional/hash.hpp>
#include <boost/unordered_map.hpp>

#include "moses/FF/StatefulFeatureFunction.h"
#include "moses/PP/CountsPhraseProperty.h"
#include "moses/TranslationOptionList.h"
#include "moses/TranslationOption.h"
#include "moses/Util.h"
#include "moses/TypeDef.h"
#include "moses/StaticData.h"
#include "moses/Phrase.h"
#include "moses/AlignmentInfo.h"
#include "moses/Word.h"
#include "moses/FactorCollection.h"

#include "Normalizer.h"
#include "Classifier.h"
#include "VWFeatureBase.h"
#include "TabbedSentence.h"
#include "ThreadLocalByFeatureStorage.h"
#include "TrainingLoss.h"
#include "VWTargetSentence.h"

/*
 * VW classifier feature. See vw/README.md for further information.
 *
 * TODO: say which paper to cite.
 */

namespace Moses
{

// Placeholder class label. VW ignores the actual label text in our setup;
// other classifier back-ends might inspect it.
const std::string VW_DUMMY_LABEL = "1111";

// One classifier instance per thread.
using TLSClassifier = ThreadLocalByFeatureStorage<Discriminative::Classifier, Discriminative::ClassifierFactory &>;

// The target sentence currently being processed; populated only during VW
// training (vwtrainer), not during decoding (prediction time).
using TLSTargetSentence = ThreadLocalByFeatureStorage<VWTargetSentence>;

// Feature vectors keyed by hash.
using FeatureVectorMap = boost::unordered_map<size_t, Discriminative::FeatureVector>;

// Per-thread feature-vector hash table.
using TLSFeatureVectorMap = ThreadLocalByFeatureStorage<FeatureVectorMap>;

// Partial scores keyed by hash.
using FloatHashMap = boost::unordered_map<size_t, float>;

// Per-thread score hash table, used for caching.
using TLSFloatHashMap = ThreadLocalByFeatureStorage<FloatHashMap>;

// Per-thread hash table caching full classifier outputs.
using TLSStateExtensions = ThreadLocalByFeatureStorage<boost::unordered_map<size_t, FloatHashMap> >;

/*
 * VW feature function. A discriminative classifier with source and target context features.
 */
class VW : public StatefulFeatureFunction, public TLSTargetSentence
{
public:
  VW(const std::string &line);

  virtual ~VW();

  bool IsUseable(const FactorMask &mask) const {
    return true;
  }

  // No-op: all scoring happens with source/target context, not in isolation.
  void EvaluateInIsolation(const Phrase &source
                           , const TargetPhrase &targetPhrase
                           , ScoreComponentCollection &scoreBreakdown
                           , ScoreComponentCollection &estimatedFutureScore) const {
  }

  // No-op: scoring is done per translation-option list, see below.
  void EvaluateWithSourceContext(const InputType &input
                                 , const InputPath &inputPath
                                 , const TargetPhrase &targetPhrase
                                 , const StackVec *stackVec
                                 , ScoreComponentCollection &scoreBreakdown
                                 , ScoreComponentCollection *estimatedFutureScore = NULL) const {
  }

  // The behavior of this method depends on whether it's called during VW
  // training (feature extraction) by vwtrainer or during decoding (prediction
  // time) by Moses.
  //
  // When predicting, it evaluates all translation options with the VW model;
  // if no target-context features are defined, this is the final score and it
  // is added directly to the TranslationOption score. If there are target
  // context features, the score is a partial score and it is only stored in
  // cache; the final score is computed based on target context in
  // EvaluateWhenApplied().
  //
  // This method is also used in training by vwtrainer in which case features
  // are written to a file, no classifier predictions take place. Target-side
  // context is constant at training time (we know the true target sentence),
  // so target-context features are extracted here as well.
  virtual void EvaluateTranslationOptionListWithSourceContext(const InputType &input
      , const TranslationOptionList &translationOptionList) const;

  // Evaluate VW during decoding. This is only used at prediction time (not in training).
  // When no target-context features are defined, VW predictions were already fully calculated
  // in EvaluateTranslationOptionListWithSourceContext() and the scores were added to the model.
  // If there are target-context features, we compute the context-dependent part of the
  // classifier score and combine it with the source-context only partial score which was computed
  // in EvaluateTranslationOptionListWithSourceContext(). Various caches are used to make this
  // method more efficient.
  virtual FFState* EvaluateWhenApplied(
    const Hypothesis& curHypo,
    const FFState* prevState,
    ScoreComponentCollection* accumulator) const;

  virtual FFState* EvaluateWhenApplied(
    const ChartHypothesis&,
    int,
    ScoreComponentCollection* accumulator) const {
    // Throw by value (not `throw new`): a pointer-throw would bypass standard
    // catch(const std::exception&) handlers and leak the allocation.
    throw std::logic_error("hierarchical/syntax not supported");
  }

  // Initial VW state; contains unaligned BOS symbols.
  const FFState* EmptyHypothesisState(const InputType &input) const;

  void SetParameter(const std::string& key, const std::string& value);

  // At prediction time, this clears our caches. At training time, we load the next sentence, its
  // translation and word alignment.
  virtual void InitializeForInput(ttasksptr const& ttask);

private:
  inline std::string MakeTargetLabel(const TargetPhrase &targetPhrase) const {
    return VW_DUMMY_LABEL; // VW does not care about class labels in our setting (--csoaa_ldf mc).
  }

  // Cache key for a (previous state, source span) combination, built by
  // hashing the three components together.
  inline size_t MakeCacheKey(const FFState *prevState, size_t spanStart, size_t spanEnd) const {
    size_t key = 0;
    boost::hash_combine(key, prevState);
    boost::hash_combine(key, spanStart);
    boost::hash_combine(key, spanEnd);
    return key;
  }

  // used in decoding to transform the global word alignment information into
  // context-phrase internal alignment information (i.e., with target indices corresponding
  // to positions in contextPhrase)
  const AlignmentInfo *TransformAlignmentInfo(const Hypothesis &curHypo, size_t contextSize) const;

  // used during training to extract relevant alignment points from the full sentence alignment
  // and shift them by target context size
  AlignmentInfo TransformAlignmentInfo(const AlignmentInfo &alignInfo, size_t contextSize, int currentStart) const;

  // At training time, determine whether a translation option is correct for the current target sentence
  // based on word alignment. This is a bit complicated because we need to handle various corner-cases
  // where some word(s) on phrase borders are unaligned.
  std::pair<bool, int> IsCorrectTranslationOption(const TranslationOption &topt) const;

  // At training time, optionally discount occurrences of phrase pairs from the current sentence, helps prevent
  // over-fitting.
  std::vector<bool> LeaveOneOut(const TranslationOptionList &topts, const std::vector<bool> &correct) const;

  bool m_train; // false means predict
  std::string m_modelPath; // path to the VW model file; at training time, this is where extracted features are stored
  std::string m_vwOptions; // options for Vowpal Wabbit

  // BOS token, all factors
  Word m_sentenceStartWord;

  // calculator of training loss
  TrainingLoss *m_trainingLoss = NULL;

  // optionally contains feature name of a phrase table where we recompute scores with leaving one out
  std::string m_leaveOneOut;

  // normalizer, typically this means softmax
  Discriminative::Normalizer *m_normalizer = NULL;

  // thread-specific classifier instance
  TLSClassifier *m_tlsClassifier;

  // caches for partial scores and feature vectors
  TLSFloatHashMap *m_tlsFutureScores;
  TLSStateExtensions *m_tlsComputedStateExtensions;
  TLSFeatureVectorMap *m_tlsTranslationOptionFeatures, *m_tlsTargetContextFeatures;
};

}