|
#include <string> |
|
#include <map> |
|
#include <limits> |
|
#include <vector> |
|
|
|
#include <boost/unordered_map.hpp> |
|
#include <boost/functional/hash.hpp> |
|
|
|
#include "moses/FF/StatefulFeatureFunction.h" |
|
#include "moses/PP/CountsPhraseProperty.h" |
|
#include "moses/TranslationOptionList.h" |
|
#include "moses/TranslationOption.h" |
|
#include "moses/Util.h" |
|
#include "moses/TypeDef.h" |
|
#include "moses/StaticData.h" |
|
#include "moses/Phrase.h" |
|
#include "moses/AlignmentInfo.h" |
|
#include "moses/AlignmentInfoCollection.h" |
|
#include "moses/Word.h" |
|
#include "moses/FactorCollection.h" |
|
|
|
#include "Normalizer.h" |
|
#include "Classifier.h" |
|
#include "VWFeatureBase.h" |
|
#include "TabbedSentence.h" |
|
#include "ThreadLocalByFeatureStorage.h" |
|
#include "TrainingLoss.h" |
|
#include "VWTargetSentence.h" |
|
#include "VWState.h" |
|
#include "VW.h" |
|
|
|
namespace Moses |
|
{ |
|
|
|
VW::VW(const std::string &line) |
|
: StatefulFeatureFunction(1, line) |
|
, TLSTargetSentence(this) |
|
, m_train(false) |
|
, m_sentenceStartWord(Word()) |
|
{ |
|
ReadParameters(); |
|
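  // training mode needs only the model/data path; prediction additionally passes the user-supplied VW options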
Discriminative::ClassifierFactory *classifierFactory = m_train |
|
? new Discriminative::ClassifierFactory(m_modelPath) |
|
: new Discriminative::ClassifierFactory(m_modelPath, m_vwOptions); |
|
|
|
m_tlsClassifier = new TLSClassifier(this, *classifierFactory); |
|
|
|
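  // per-thread caches that carry pre-computed scores and feature vectors between
  // EvaluateTranslationOptionListWithSourceContext() and EvaluateWhenApplied()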
m_tlsFutureScores = new TLSFloatHashMap(this); |
|
m_tlsComputedStateExtensions = new TLSStateExtensions(this); |
|
m_tlsTranslationOptionFeatures = new TLSFeatureVectorMap(this); |
|
m_tlsTargetContextFeatures = new TLSFeatureVectorMap(this); |
|
|
|
if (! m_normalizer) { |
|
VERBOSE(1, "VW :: No loss function specified, assuming logistic loss.\n"); |
|
m_normalizer = (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer(); |
|
} |
|
|
|
if (! m_trainingLoss) { |
|
VERBOSE(1, "VW :: Using basic 1/0 loss calculation in training.\n"); |
|
m_trainingLoss = (TrainingLoss *) new TrainingLossBasic(); |
|
} |
|
|
|
|
|
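  // build a virtual beginning-of-sentence word with all factors set to BOS;
  // it is used to pad the target-side context at sentence start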
const Factor *bosFactor = FactorCollection::Instance().AddFactor(BOS_); |
|
for (size_t i = 0; i < MAX_NUM_FACTORS; i++) |
|
m_sentenceStartWord.SetFactor(i, bosFactor); |
|
} |
|
|
|
VW::~VW() |
|
{ |
|
delete m_tlsClassifier; |
|
delete m_normalizer; |
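  // note: the remaining thread-local caches and m_trainingLoss are not freed here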
|
|
|
} |
|
|
|
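// Stateful evaluation: score the applied translation option with target-context features,
// caching normalized scores for all options that cover the same span under the same previous state.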
FFState* VW::EvaluateWhenApplied( |
|
const Hypothesis& curHypo, |
|
const FFState* prevState, |
|
ScoreComponentCollection* accumulator) const |
|
{ |
|
VERBOSE(3, "VW :: Evaluating translation options\n"); |
|
|
|
const VWState& prevVWState = *static_cast<const VWState *>(prevState); |
|
|
|
const std::vector<VWFeatureBase*>& contextFeatures = |
|
VWFeatureBase::GetTargetContextFeatures(GetScoreProducerDescription()); |
|
|
|
if (contextFeatures.empty()) { |
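    // no target-context features; everything was already scored in
    // EvaluateTranslationOptionListWithSourceContext(), so there is no state to track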
|
|
|
|
|
|
|
return new VWState(); |
|
} |
|
|
|
size_t spanStart = curHypo.GetTranslationOption().GetStartPos(); |
|
size_t spanEnd = curHypo.GetTranslationOption().GetEndPos(); |
|
|
|
|
|
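  // cache key: previous context state combined with the covered source span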
size_t cacheKey = MakeCacheKey(prevState, spanStart, spanEnd); |
|
|
|
boost::unordered_map<size_t, FloatHashMap> &computedStateExtensions |
|
= *m_tlsComputedStateExtensions->GetStored(); |
|
|
|
if (computedStateExtensions.find(cacheKey) == computedStateExtensions.end()) { |
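    // this context/span combination has not been scored yet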
|
|
|
const TranslationOptionList *topts = |
|
curHypo.GetManager().getSntTranslationOptions()->GetTranslationOptionList(spanStart, spanEnd); |
|
|
|
const InputType& input = curHypo.GetManager().GetSource(); |
|
|
|
Discriminative::Classifier &classifier = *m_tlsClassifier->GetStored(); |
|
|
|
|
|
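    // label-independent features of the target-side context, cached by the hash of the previous state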
size_t contextHash = prevVWState.hash(); |
|
|
|
FeatureVectorMap &contextFeaturesCache = *m_tlsTargetContextFeatures->GetStored(); |
|
|
|
FeatureVectorMap::const_iterator contextIt = contextFeaturesCache.find(contextHash); |
|
if (contextIt == contextFeaturesCache.end()) { |
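      // context not seen yet: extract its features (the extractors also feed them
      // into the classifier) and remember the resulting feature vector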
|
|
|
|
|
const Phrase &targetContext = prevVWState.GetPhrase(); |
|
Discriminative::FeatureVector contextVector; |
|
const AlignmentInfo *alignInfo = TransformAlignmentInfo(curHypo, targetContext.GetSize()); |
|
for(size_t i = 0; i < contextFeatures.size(); ++i) |
|
(*contextFeatures[i])(input, targetContext, *alignInfo, classifier, contextVector); |
|
|
|
contextFeaturesCache[contextHash] = contextVector; |
|
VERBOSE(3, "VW :: context cache miss\n"); |
|
} else { |
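      // cache hit: replay the stored label-independent features into the classifier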
|
|
|
classifier.AddLabelIndependentFeatureVector(contextIt->second); |
|
VERBOSE(3, "VW :: context cache hit\n"); |
|
} |
|
|
|
std::vector<float> losses(topts->size()); |
|
|
|
for (size_t toptIdx = 0; toptIdx < topts->size(); toptIdx++) { |
|
const TranslationOption *topt = topts->Get(toptIdx); |
|
const TargetPhrase &targetPhrase = topt->GetTargetPhrase(); |
|
size_t toptHash = hash_value(*topt); |
|
|
|
|
|
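      // start from the pre-computed source-context ("future") score of this option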
losses[toptIdx] = m_tlsFutureScores->GetStored()->find(toptHash)->second; |
|
|
|
|
|
|
|
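      // label-dependent target features were pre-computed and cached per translation option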
const Discriminative::FeatureVector &targetFeatureVector = |
|
m_tlsTranslationOptionFeatures->GetStored()->find(toptHash)->second; |
|
|
|
classifier.AddLabelDependentFeatureVector(targetFeatureVector); |
|
|
|
|
|
losses[toptIdx] += classifier.Predict(MakeTargetLabel(targetPhrase)); |
|
} |
|
|
|
|
|
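    // normalize the raw classifier scores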
(*m_normalizer)(losses); |
|
|
|
|
|
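    // cache the transformed scores, keyed by translation option hash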
FloatHashMap &toptScores = computedStateExtensions[cacheKey]; |
|
for (size_t toptIdx = 0; toptIdx < topts->size(); toptIdx++) { |
|
const TranslationOption *topt = topts->Get(toptIdx); |
|
size_t toptHash = hash_value(*topt); |
|
toptScores[toptHash] = FloorScore(TransformScore(losses[toptIdx])); |
|
} |
|
|
|
VERBOSE(3, "VW :: cache miss\n"); |
|
} else { |
|
VERBOSE(3, "VW :: cache hit\n"); |
|
} |
|
|
|
|
|
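  // the cache now holds the score of the current translation option; look it up and add it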
std::vector<float> newScores(m_numScoreComponents); |
|
size_t toptHash = hash_value(curHypo.GetTranslationOption()); |
|
newScores[0] = computedStateExtensions[cacheKey][toptHash]; |
|
VERBOSE(3, "VW :: adding score: " << newScores[0] << "\n"); |
|
accumulator->PlusEquals(this, newScores); |
|
|
|
return new VWState(prevVWState, curHypo); |
|
} |
|
|
|
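// The initial state is a target context window filled with beginning-of-sentence words.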
const FFState* VW::EmptyHypothesisState(const InputType &input) const |
|
{ |
|
size_t maxContextSize = VWFeatureBase::GetMaximumContextSize(GetScoreProducerDescription()); |
|
Phrase initialPhrase; |
|
for (size_t i = 0; i < maxContextSize; i++) |
|
initialPhrase.AddWord(m_sentenceStartWord); |
|
|
|
return new VWState(initialPhrase); |
|
} |
|
|
|
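// Evaluate all translation options for a source span. In training mode, VW training
// examples are generated from the aligned reference; in decoding, the classifier scores
// each option and partial results are cached for EvaluateWhenApplied() when
// target-context features are in use.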
void VW::EvaluateTranslationOptionListWithSourceContext(const InputType &input |
|
, const TranslationOptionList &translationOptionList) const |
|
{ |
|
Discriminative::Classifier &classifier = *m_tlsClassifier->GetStored(); |
|
|
|
if (translationOptionList.size() == 0) |
|
return; |
|
|
|
VERBOSE(3, "VW :: Evaluating translation options\n"); |
|
|
|
|
|
const std::vector<VWFeatureBase*>& sourceFeatures = |
|
VWFeatureBase::GetSourceFeatures(GetScoreProducerDescription()); |
|
|
|
const std::vector<VWFeatureBase*>& contextFeatures = |
|
VWFeatureBase::GetTargetContextFeatures(GetScoreProducerDescription()); |
|
|
|
const std::vector<VWFeatureBase*>& targetFeatures = |
|
VWFeatureBase::GetTargetFeatures(GetScoreProducerDescription()); |
|
|
|
size_t maxContextSize = VWFeatureBase::GetMaximumContextSize(GetScoreProducerDescription()); |
|
|
|
|
|
bool haveTargetContextFeatures = ! contextFeatures.empty(); |
|
|
|
const Range &sourceRange = translationOptionList.Get(0)->GetSourceWordsRange(); |
|
|
|
if (m_train) { |
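    // training: first determine which options match the aligned reference
    // and at which target position they start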
|
|
|
|
|
|
|
|
|
|
|
std::vector<bool> correct(translationOptionList.size()); |
|
std::vector<int> startsAt(translationOptionList.size()); |
|
std::set<int> uncoveredStartingPositions; |
|
|
|
for (size_t i = 0; i < translationOptionList.size(); i++) { |
|
std::pair<bool, int> isCorrect = IsCorrectTranslationOption(* translationOptionList.Get(i)); |
|
correct[i] = isCorrect.first; |
|
startsAt[i] = isCorrect.second; |
|
if (isCorrect.first) { |
|
uncoveredStartingPositions.insert(isCorrect.second); |
|
} |
|
} |
|
|
|
|
|
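    // optionally adjust phrase-table scores by leaving the current sentence out of the counts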
std::vector<bool> keep = (m_leaveOneOut.size() > 0) |
|
? LeaveOneOut(translationOptionList, correct) |
|
: std::vector<bool>(translationOptionList.size(), true); |
|
|
|
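    // generate training examples separately for each target start position covered by a correct option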
while (! uncoveredStartingPositions.empty()) { |
|
int currentStart = *uncoveredStartingPositions.begin(); |
|
uncoveredStartingPositions.erase(uncoveredStartingPositions.begin()); |
|
|
|
|
|
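      // find the first correct option starting at this position (if any survived leave-one-out)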
int firstCorrect = -1; |
|
for (size_t i = 0; i < translationOptionList.size(); i++) { |
|
if (keep[i] && correct[i] && startsAt[i] == currentStart) { |
|
firstCorrect = i; |
|
break; |
|
} |
|
} |
|
|
|
|
|
if (firstCorrect == -1) { |
|
VERBOSE(3, "VW :: skipping topt collection, no correct translation for span at current tgt start position\n"); |
|
continue; |
|
} |
|
|
|
|
|
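      // this option's target phrase serves as the reference for loss computation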
const TargetPhrase &correctPhrase = translationOptionList.Get(firstCorrect)->GetTargetPhrase(); |
|
|
|
|
|
|
|
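      // features are passed directly to the classifier; the output vector is not used in training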
Discriminative::FeatureVector dummyVector; |
|
|
|
|
|
for(size_t i = 0; i < sourceFeatures.size(); ++i) |
|
(*sourceFeatures[i])(input, sourceRange, classifier, dummyVector); |
|
|
|
|
|
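      // target context: BOS padding followed by the reference words preceding currentStart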
Phrase targetContext; |
|
for (size_t i = 0; i < maxContextSize; i++) |
|
targetContext.AddWord(m_sentenceStartWord); |
|
|
|
const Phrase *targetSent = GetStored()->m_sentence; |
|
|
|
|
|
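      // shift the reference alignment into context-window coordinates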
AlignmentInfo contextAlignment = TransformAlignmentInfo(*GetStored()->m_alignment, maxContextSize, currentStart); |
|
|
|
if (currentStart > 0) |
|
targetContext.Append(targetSent->GetSubString(Range(0, currentStart - 1))); |
|
|
|
|
|
for(size_t i = 0; i < contextFeatures.size(); ++i) |
|
(*contextFeatures[i])(input, targetContext, contextAlignment, classifier, dummyVector); |
|
|
|
|
|
for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) { |
|
|
|
|
|
if (! keep[toptIdx]) |
|
continue; |
|
|
|
|
|
const TargetPhrase &targetPhrase = translationOptionList.Get(toptIdx)->GetTargetPhrase(); |
|
for(size_t i = 0; i < targetFeatures.size(); ++i) |
|
(*targetFeatures[i])(input, targetPhrase, classifier, dummyVector); |
|
|
|
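        // compute the training loss of this option against the reference phrase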
bool isCorrect = correct[toptIdx] && startsAt[toptIdx] == currentStart; |
|
float loss = (*m_trainingLoss)(targetPhrase, correctPhrase, isCorrect); |
|
|
|
|
|
classifier.Train(MakeTargetLabel(targetPhrase), loss); |
|
} |
|
} |
|
} else { |
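    // decoding: score the translation options with the trained classifier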
|
|
|
|
|
|
|
|
|
std::vector<float> losses(translationOptionList.size()); |
|
|
|
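    // label-independent source-context features, shared by all options in this list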
Discriminative::FeatureVector outFeaturesSourceNamespace; |
|
|
|
|
|
for(size_t i = 0; i < sourceFeatures.size(); ++i) |
|
(*sourceFeatures[i])(input, sourceRange, classifier, outFeaturesSourceNamespace); |
|
|
|
for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) { |
|
const TranslationOption *topt = translationOptionList.Get(toptIdx); |
|
const TargetPhrase &targetPhrase = topt->GetTargetPhrase(); |
|
Discriminative::FeatureVector outFeaturesTargetNamespace; |
|
|
|
|
|
for(size_t i = 0; i < targetFeatures.size(); ++i) |
|
(*targetFeatures[i])(input, targetPhrase, classifier, outFeaturesTargetNamespace); |
|
|
|
|
|
|
|
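      // cache the target feature vector so that EvaluateWhenApplied() can reuse it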
size_t toptHash = hash_value(*topt); |
|
m_tlsTranslationOptionFeatures->GetStored()->insert( |
|
std::make_pair(toptHash, outFeaturesTargetNamespace)); |
|
|
|
|
|
losses[toptIdx] = classifier.Predict(MakeTargetLabel(targetPhrase)); |
|
} |
|
|
|
|
|
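    // keep the raw scores; normalization modifies 'losses' in place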
std::vector<float> rawLosses = losses; |
|
(*m_normalizer)(losses); |
|
|
|
|
|
for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) { |
|
TranslationOption *topt = *(translationOptionList.begin() + toptIdx); |
|
if (! haveTargetContextFeatures) { |
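      // no target-context features: the score is final, add it to the option's score breakdown now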
|
|
|
std::vector<float> newScores(m_numScoreComponents); |
|
newScores[0] = FloorScore(TransformScore(losses[toptIdx])); |
|
|
|
ScoreComponentCollection &scoreBreakDown = topt->GetScoreBreakdown(); |
|
scoreBreakDown.PlusEquals(this, newScores); |
|
|
|
topt->UpdateScore(); |
|
} else { |
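      // target-context features are enabled, so the final score is only known in
      // EvaluateWhenApplied(); store a partial ("future") score instead: the raw
      // source+target score minus a target-only prediction, i.e. roughly the
      // source-context contribution, so the target features are not counted twice later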
|
|
|
|
|
size_t toptHash = hash_value(*topt); |
|
|
|
|
|
|
|
Discriminative::FeatureVector emptySource; |
|
const Discriminative::FeatureVector &targetFeatureVector = |
|
m_tlsTranslationOptionFeatures->GetStored()->find(toptHash)->second; |
|
classifier.AddLabelIndependentFeatureVector(emptySource); |
|
classifier.AddLabelDependentFeatureVector(targetFeatureVector); |
|
float targetOnlyLoss = classifier.Predict(VW_DUMMY_LABEL); |
|
|
|
float futureScore = rawLosses[toptIdx] - targetOnlyLoss; |
|
m_tlsFutureScores->GetStored()->insert(std::make_pair(toptHash, futureScore)); |
|
} |
|
} |
|
} |
|
} |
|
|
|
void VW::SetParameter(const std::string& key, const std::string& value) |
|
{ |
|
if (key == "train") { |
|
m_train = Scan<bool>(value); |
|
} else if (key == "path") { |
|
m_modelPath = value; |
|
} else if (key == "vw-options") { |
|
m_vwOptions = value; |
|
} else if (key == "leave-one-out-from") { |
|
m_leaveOneOut = value; |
|
} else if (key == "training-loss") { |
|
|
|
if (value == "basic") { |
|
m_trainingLoss = (TrainingLoss *) new TrainingLossBasic(); |
|
} else if (value == "bleu") { |
|
m_trainingLoss = (TrainingLoss *) new TrainingLossBLEU(); |
|
} else { |
|
UTIL_THROW2("Unknown training loss type:" << value); |
|
} |
|
} else if (key == "loss") { |
|
|
|
|
|
if (value == "logistic") { |
|
m_normalizer = (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer(); |
|
} else if (value == "squared") { |
|
m_normalizer = (Discriminative::Normalizer *) new Discriminative::SquaredLossNormalizer(); |
|
} else { |
|
UTIL_THROW2("Unknown loss type:" << value); |
|
} |
|
} else { |
|
StatefulFeatureFunction::SetParameter(key, value); |
|
} |
|
} |
|
|
|
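// Called before each sentence: reset the per-sentence caches and, in training, read the
// reference translation and its word alignment from the TabbedSentence input.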
void VW::InitializeForInput(ttasksptr const& ttask) |
|
{ |
|
|
|
m_tlsFutureScores->GetStored()->clear(); |
|
|
|
|
|
m_tlsComputedStateExtensions->GetStored()->clear(); |
|
|
|
|
|
|
|
|
|
|
|
m_tlsTargetContextFeatures->GetStored()->clear(); |
|
m_tlsTranslationOptionFeatures->GetStored()->clear(); |
|
|
|
InputType const& source = *(ttask->GetSource().get()); |
|
|
|
if (! m_train) |
|
return; |
|
|
|
UTIL_THROW_IF2(source.GetType() != TabbedSentenceInput, |
|
"This feature function requires the TabbedSentence input type"); |
|
|
|
const TabbedSentence& tabbedSentence = static_cast<const TabbedSentence&>(source); |
|
UTIL_THROW_IF2(tabbedSentence.GetColumns().size() < 2, |
|
"TabbedSentence must contain target<tab>alignment"); |
|
|
|
|
|
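  // the reference translation is the first tab-separated column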
Phrase *target = new Phrase(); |
|
target->CreateFromString( |
|
Output |
|
, StaticData::Instance().options()->output.factor_order |
|
, tabbedSentence.GetColumns()[0] |
|
, NULL); |
|
|
|
|
|
|
|
|
|
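  // the word alignment between source and reference is the second column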
AlignmentInfo *alignment = new AlignmentInfo(tabbedSentence.GetColumns()[1]); |
|
|
|
VWTargetSentence &targetSent = *GetStored(); |
|
targetSent.Clear(); |
|
targetSent.m_sentence = target; |
|
targetSent.m_alignment = alignment; |
|
|
|
|
|
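  // pre-compute the alignment-based constraints used by IsCorrectTranslationOption()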
targetSent.SetConstraints(source.GetSize()); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
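// Build alignment points for the target context of a partial hypothesis: walk backwards
// over previous hypotheses, map the last contextSize target words to positions in the
// context window and collect their source alignment links.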
const AlignmentInfo *VW::TransformAlignmentInfo(const Hypothesis &curHypo, size_t contextSize) const |
|
{ |
|
std::set<std::pair<size_t, size_t> > alignmentPoints; |
|
const Hypothesis *contextHypo = curHypo.GetPrevHypo(); |
|
int idxInContext = contextSize - 1; |
|
int processedWordsInHypo = 0; |
|
while (idxInContext >= 0 && contextHypo) { |
|
int idxInHypo = contextHypo->GetCurrTargetLength() - 1 - processedWordsInHypo; |
|
if (idxInHypo >= 0) { |
|
const AlignmentInfo &hypoAlign = contextHypo->GetCurrTargetPhrase().GetAlignTerm(); |
|
std::set<size_t> alignedToTgt = hypoAlign.GetAlignmentsForTarget(idxInHypo); |
|
size_t srcOffset = contextHypo->GetCurrSourceWordsRange().GetStartPos(); |
|
BOOST_FOREACH(size_t srcIdx, alignedToTgt) { |
|
alignmentPoints.insert(std::make_pair(srcOffset + srcIdx, idxInContext)); |
|
} |
|
processedWordsInHypo++; |
|
idxInContext--; |
|
} else { |
|
processedWordsInHypo = 0; |
|
contextHypo = contextHypo->GetPrevHypo(); |
|
} |
|
} |
|
|
|
return AlignmentInfoCollection::Instance().Add(alignmentPoints); |
|
} |
|
|
|
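// Training-time variant: remap reference alignment points of the words just before
// currentStart so that target indices refer to positions in the context window.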
AlignmentInfo VW::TransformAlignmentInfo(const AlignmentInfo &alignInfo, size_t contextSize, int currentStart) const |
|
{ |
|
std::set<std::pair<size_t, size_t> > alignmentPoints; |
|
for (int i = std::max(0, currentStart - (int)contextSize); i < currentStart; i++) { |
|
std::set<size_t> alignedToTgt = alignInfo.GetAlignmentsForTarget(i); |
|
BOOST_FOREACH(size_t srcIdx, alignedToTgt) { |
|
alignmentPoints.insert(std::make_pair(srcIdx, i + contextSize)); |
|
} |
|
} |
|
return AlignmentInfo(alignmentPoints); |
|
} |
|
|
|
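// Check whether a translation option is consistent with the aligned reference: returns
// (true, targetStart) if its target phrase matches the reference at an admissible
// position, (false, -1) otherwise.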
std::pair<bool, int> VW::IsCorrectTranslationOption(const TranslationOption &topt) const |
|
{ |
|
|
|
|
|
|
|
int sourceStart = topt.GetSourceWordsRange().GetStartPos(); |
|
int sourceEnd = topt.GetSourceWordsRange().GetEndPos(); |
|
|
|
const VWTargetSentence &targetSentence = *GetStored(); |
|
|
|
|
|
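  // minimal target span required by the alignment of the covered source words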
int targetStart = targetSentence.m_sentence->GetSize(); |
|
int targetEnd = -1; |
|
|
|
|
|
for(int i = sourceStart; i <= sourceEnd; ++i) { |
|
if(targetSentence.m_sourceConstraints[i].IsSet()) { |
|
if(targetStart > targetSentence.m_sourceConstraints[i].GetMin()) |
|
targetStart = targetSentence.m_sourceConstraints[i].GetMin(); |
|
if(targetEnd < targetSentence.m_sourceConstraints[i].GetMax()) |
|
targetEnd = targetSentence.m_sourceConstraints[i].GetMax(); |
|
} |
|
} |
|
|
|
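  // no aligned target words for this source span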
if(targetEnd == -1) |
|
return std::make_pair(false, -1); |
|
|
|
|
|
|
|
|
|
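  // extend over unconstrained (unaligned) target positions to get the maximal admissible span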
int targetStart2 = targetStart; |
|
for(int i = targetStart2; i >= 0 && !targetSentence.m_targetConstraints[i].IsSet(); --i) |
|
targetStart2 = i; |
|
|
|
int targetEnd2 = targetEnd; |
|
for(int i = targetEnd2; |
|
i < targetSentence.m_sentence->GetSize() && !targetSentence.m_targetConstraints[i].IsSet(); |
|
++i) |
|
targetEnd2 = i; |
|
|
|
|
|
|
|
const TargetPhrase &tphrase = topt.GetTargetPhrase(); |
|
|
|
|
|
|
|
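  // the target phrase must cover at least the minimal span and fit within the maximal span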
if(tphrase.GetSize() < targetEnd - targetStart + 1) |
|
return std::make_pair(false, -1); |
|
|
|
|
|
if(tphrase.GetSize() > targetEnd2 - targetStart2 + 1) |
|
return std::make_pair(false, -1); |
|
|
|
|
|
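  // try matching the target phrase against the reference at each admissible start position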
for(int tempStart = targetStart2; tempStart <= targetStart; tempStart++) { |
|
bool found = true; |
|
|
|
for(int i = tempStart; i <= targetEnd2 && i < tphrase.GetSize() + tempStart; ++i) { |
|
if(tphrase.GetWord(i - tempStart) != targetSentence.m_sentence->GetWord(i)) { |
|
found = false; |
|
break; |
|
} |
|
} |
|
|
|
if(found) { |
|
|
|
return std::make_pair(true, tempStart); |
|
} |
|
} |
|
|
|
return std::make_pair(false, -1); |
|
} |
|
|
|
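// Leave-one-out for training: discount the counts contributed by the current sentence
// from the phrase-table feature named by m_leaveOneOut and recompute its scores;
// returns a mask of translation options to keep.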
std::vector<bool> VW::LeaveOneOut(const TranslationOptionList &topts, const std::vector<bool> &correct) const |
|
{ |
|
UTIL_THROW_IF2(m_leaveOneOut.size() == 0 || ! m_train, "LeaveOneOut called in wrong setting!"); |
|
|
|
float sourceRawCount = 0.0; |
|
const float ONE = 1.0001; |
|
|
|
std::vector<bool> keepOpt; |
|
|
|
for (size_t i = 0; i < topts.size(); i++) { |
|
TranslationOption *topt = *(topts.begin() + i); |
|
const TargetPhrase &targetPhrase = topt->GetTargetPhrase(); |
|
|
|
|
|
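    // leave-one-out requires the Counts phrase property (joint and marginal counts)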
const CountsPhraseProperty *property = |
|
static_cast<const CountsPhraseProperty *>(targetPhrase.GetProperty("Counts")); |
|
|
|
if (! property) { |
|
VERBOSE(2, "VW :: Counts not found for topt! Is this an OOV?\n"); |
|
|
|
keepOpt.assign(topts.size(), true); |
|
return keepOpt; |
|
} |
|
|
|
if (sourceRawCount == 0.0) { |
|
sourceRawCount = property->GetSourceMarginal() - ONE; |
|
if (sourceRawCount <= 0) { |
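      // source marginal count is 1, i.e. the phrase occurs only in the current sentence; discard all options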
|
|
|
keepOpt.assign(topts.size(), false); |
|
return keepOpt; |
|
} |
|
} |
|
|
|
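    // subtract the current sentence's contribution only from options matching the reference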
float discount = correct[i] ? ONE : 0.0; |
|
float target = property->GetTargetMarginal() - discount; |
|
float joint = property->GetJointCount() - discount; |
|
if (discount != 0.0) VERBOSE(3, "VW :: leaving one out!\n"); |
|
|
|
if (joint > 0) { |
|
|
|
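      // recompute the phrase translation probabilities (score indices 0 and 2) from the
      // discounted counts; the lexical weights are left unchanged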
const FeatureFunction *feature = &FindFeatureFunction(m_leaveOneOut); |
|
std::vector<float> scores = targetPhrase.GetScoreBreakdown().GetScoresForProducer(feature); |
|
UTIL_THROW_IF2(scores.size() != 4, "Unexpected number of scores in feature " << m_leaveOneOut); |
|
scores[0] = TransformScore(joint / target); |
|
scores[2] = TransformScore(joint / sourceRawCount); |
|
|
|
ScoreComponentCollection &scoreBreakDown = topt->GetScoreBreakdown(); |
|
scoreBreakDown.Assign(feature, scores); |
|
topt->UpdateScore(); |
|
keepOpt.push_back(true); |
|
} else { |
|
|
|
VERBOSE(2, "VW :: discarded topt when leaving one out\n"); |
|
keepOpt.push_back(false); |
|
} |
|
} |
|
|
|
return keepOpt; |
|
} |
|
|
|
} |
|
|