|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <limits> |
|
#include <iostream> |
|
#include <memory> |
|
#include <sstream> |
|
|
|
#include "moses/FF/FFState.h" |
|
#include "Implementation.h" |
|
#include "ChartState.h" |
|
#include "moses/TypeDef.h" |
|
#include "moses/Util.h" |
|
#include "moses/Manager.h" |
|
#include "moses/FactorCollection.h" |
|
#include "moses/Phrase.h" |
|
#include "moses/StaticData.h" |
|
#include "moses/ChartManager.h" |
|
#include "moses/ChartHypothesis.h" |
|
#include "util/exception.hh" |
|
|
|
using namespace std; |
|
|
|
namespace Moses |
|
{ |
|
// Constructs the LM from its feature configuration line. The n-gram order
// starts out as NOT_FOUND; SetParameter() overwrites it when the "order"
// key is given.
LanguageModelImplementation::LanguageModelImplementation(const std::string &line)
  : LanguageModel(line)
  , m_nGramOrder(NOT_FOUND)
{
}
|
|
|
void LanguageModelImplementation::SetParameter(const std::string& key, const std::string& value) |
|
{ |
|
if (key == "order") { |
|
m_nGramOrder = Scan<size_t>(value); |
|
} else if (key == "path") { |
|
m_filePath = value; |
|
} else { |
|
LanguageModel::SetParameter(key, value); |
|
} |
|
|
|
} |
|
|
|
// Adds `word` to the LM context window `contextFactor`.
// While the window holds fewer than GetNGramOrder() entries, the word is
// appended. Once the window is full, every entry moves one slot to the left
// (dropping the oldest) and the word is stored in the last slot.
// With an order of 0 the window is left untouched.
void LanguageModelImplementation::ShiftOrPush(std::vector<const Word*> &contextFactor, const Word &word) const
{
  const size_t order = GetNGramOrder();

  if (contextFactor.size() < order) {
    contextFactor.push_back(&word);
    return;
  }

  if (order == 0)
    return;

  const size_t lastSlot = order - 1;
  for (size_t slot = 0; slot < lastSlot; ++slot)
    contextFactor[slot] = contextFactor[slot + 1];
  contextFactor[lastSlot] = &word;
}
|
|
|
// Default stateful query: the supplied state is not used as extra input.
// The call is simply delegated to GetValueForgotState(), which receives the
// same `state` object (presumably to write the resulting LM state into it;
// verify against the concrete LM backend).
LMResult LanguageModelImplementation::GetValueGivenState(
  const std::vector<const Word*> &contextFactor,
  FFState &state) const
{
  const LMResult result = GetValueForgotState(contextFactor, state);
  return result;
}
|
|
|
void LanguageModelImplementation::GetState( |
|
const std::vector<const Word*> &contextFactor, |
|
FFState &state) const |
|
{ |
|
GetValueForgotState(contextFactor, state); |
|
} |
|
|
|
|
|
void LanguageModelImplementation::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const |
|
{ |
|
fullScore = 0; |
|
ngramScore = 0; |
|
|
|
oovCount = 0; |
|
|
|
size_t phraseSize = phrase.GetSize(); |
|
if (!phraseSize) return; |
|
|
|
vector<const Word*> contextFactor; |
|
contextFactor.reserve(GetNGramOrder()); |
|
std::auto_ptr<FFState> state(NewState((phrase.GetWord(0) == GetSentenceStartWord()) ? |
|
GetBeginSentenceState() : GetNullContextState())); |
|
size_t currPos = 0; |
|
while (currPos < phraseSize) { |
|
const Word &word = phrase.GetWord(currPos); |
|
|
|
if (word.IsNonTerminal()) { |
|
|
|
if (!contextFactor.empty()) { |
|
|
|
state.reset(NewState(GetNullContextState())); |
|
contextFactor.clear(); |
|
} |
|
} else { |
|
ShiftOrPush(contextFactor, word); |
|
UTIL_THROW_IF2(contextFactor.size() > GetNGramOrder(), |
|
"Can only calculate LM score of phrases up to the n-gram order"); |
|
|
|
if (word == GetSentenceStartWord()) { |
|
|
|
if (currPos != 0) { |
|
UTIL_THROW2("Either your data contains <s> in a position other than the first word or your language model is missing <s>. Did you build your ARPA using IRSTLM and forget to run add-start-end.sh?"); |
|
} |
|
} else { |
|
LMResult result = GetValueGivenState(contextFactor, *state); |
|
fullScore += result.score; |
|
if (contextFactor.size() == GetNGramOrder()) |
|
ngramScore += result.score; |
|
if (result.unknown) ++oovCount; |
|
} |
|
} |
|
|
|
currPos++; |
|
} |
|
} |
|
|
|
// Phrase-based LM scoring of a hypothesis that has just been extended by a
// target phrase.
//
// Only these n-grams are scored here:
//   * n-grams ending on one of the first (order - 1) words of the new
//     phrase. Their history reaches back into the previous hypothesis.
//   * when the source is fully covered, the n-gram predicting </s>.
// N-grams lying entirely inside the new phrase are not scored in this
// function (presumably they are pre-scored with the translation option;
// confirm against the caller).
//
// @param hypo  hypothesis extended by the current target phrase
// @param ps    LM state of the previous hypothesis; may be NULL
// @param out   receives this feature's score (plus a 0 OOV entry when the
//              OOV feature is enabled)
// @return      newly allocated LM state for `hypo`, or NULL when no state is
//              kept (order <= 1, or an empty phrase with no previous state)
FFState *LanguageModelImplementation::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const
{
  const size_t order = GetNGramOrder();

  // With order <= 1 no n-gram crosses the phrase boundary, so nothing is
  // scored and no state is carried.
  if (order <= 1)
    return NULL;

  // An empty target phrase changes nothing: carry the previous state forward.
  if (hypo.GetCurrTargetLength() == 0)
    return ps ? NewState(ps) : NULL;

  IFVERBOSE(2) {
    hypo.GetManager().GetSentenceStats().StartTimeCalcLM();
  }

  const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
  const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();

  // First n-gram: the (order - 1) words before startPos plus the word at
  // startPos. Positions before the sentence start are padded with <s>.
  vector<const Word*> contextFactor(order);
  size_t index = 0;
  for (int currPos = (int) startPos - (int) order + 1 ; currPos <= (int) startPos ; currPos++) {
    if (currPos >= 0)
      contextFactor[index++] = &hypo.GetWord(currPos);
    else
      contextFactor[index++] = &GetSentenceStartWord();
  }

  FFState *res = NewState(ps);
  // Without a previous state, fall back to the stateless query.
  float lmScore = ps ? GetValueGivenState(contextFactor, *res).score : GetValueForgotState(contextFactor, *res).score;

  // Remaining boundary-crossing n-grams: those ending on the next
  // (order - 2) words of the new phrase, limited by the phrase's end.
  const size_t endPos = std::min(startPos + order - 2, currEndPos);
  for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++) {
    // Slide the window one word to the left and append the next word.
    for (size_t i = 0 ; i < order - 1 ; i++)
      contextFactor[i] = contextFactor[i + 1];
    contextFactor.back() = &hypo.GetWord(currPos);

    lmScore += GetValueGivenState(contextFactor, *res).score;
  }

  if (hypo.IsSourceCompleted()) {
    // Score </s> given the last (order - 1) words of the whole hypothesis,
    // padding with <s> when the hypothesis is shorter than that.
    const size_t size = hypo.GetSize();
    contextFactor.back() = &GetSentenceEndWord();

    for (size_t i = 0 ; i < order - 1 ; i++) {
      // Slot i holds the word `distanceFromEnd` positions before the end.
      // Compare in size_t instead of forming a negative index through
      // unsigned wrap-around and an implementation-defined size_t -> int
      // conversion.
      const size_t distanceFromEnd = order - 1 - i;
      if (size < distanceFromEnd)
        contextFactor[i] = &GetSentenceStartWord();
      else
        contextFactor[i] = &hypo.GetWord(size - distanceFromEnd);
    }
    lmScore += GetValueForgotState(contextFactor, *res).score;
  } else if (endPos < currEndPos) {
    // The scoring loop stopped before the phrase end, so `res` does not yet
    // reflect the last words. Advance the window to the end (without
    // scoring) and recompute the state from it.
    for (size_t currPos = endPos + 1; currPos <= currEndPos; currPos++) {
      for (size_t i = 0 ; i < order - 1 ; i++)
        contextFactor[i] = contextFactor[i + 1];
      contextFactor.back() = &hypo.GetWord(currPos);
    }
    GetState(contextFactor, *res);
  }

  if (OOVFeatureEnabled()) {
    vector<float> scores(2);
    scores[0] = lmScore;
    scores[1] = 0;
    out->PlusEquals(this, scores);
  } else {
    out->PlusEquals(this, lmScore);
  }

  IFVERBOSE(2) {
    hypo.GetManager().GetSentenceStats().StopTimeCalcLM();
  }
  return res;
}
|
|
|
// Chart-based LM scoring of a hypothesis, done by walking its target-side
// rule left to right:
//   * Terminals produced by the rule are scored directly. A leading <s> only
//     switches to the begin-sentence state.
//   * For each non-terminal, the boundary words of the sub-hypothesis
//     (prefix/suffix kept in its LanguageModelChartState) are spliced into
//     the context. Prefix words are re-scored in the new context.
// Scores go either to a "prefix" accumulator or a "finalized" accumulator,
// depending on word position (see updateChartScore).
// The value added to `out` is prefixScore + finalizedScore minus the rule's
// own LM score taken from the translation option.
//
// @return a new LanguageModelChartState holding prefixScore and the final LM
//         state (ownership of that state passes to it via Set()).
FFState* LanguageModelImplementation::EvaluateWhenApplied(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection* out) const
{
  LanguageModelChartState *ret = new LanguageModelChartState(hypo, featureID, GetNGramOrder());

  // Sliding context window (history plus predicted word), at most
  // GetNGramOrder() entries.
  vector<const Word*> contextFactor;
  contextFactor.reserve(GetNGramOrder());

  // Start with no left context. This is replaced when the rule begins with
  // <s> or with a non-terminal.
  FFState *lmState = NewState( GetNullContextState() );

  float prefixScore = 0.0;    // words at positions below the n-gram order
  float finalizedScore = 0.0; // words at later positions

  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
    hypo.GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap();

  for (size_t phrasePos = 0, wordPos = 0;
       phrasePos < hypo.GetCurrTargetPhrase().GetSize();
       phrasePos++) {
    const Word &word = hypo.GetCurrTargetPhrase().GetWord(phrasePos);

    if (!word.IsNonTerminal()) {
      // Terminal word produced by the rule itself.
      ShiftOrPush(contextFactor, word);

      if (word == GetSentenceStartWord()) {
        // <s> is not scored; it only selects the begin-sentence state.
        UTIL_THROW_IF2(phrasePos != 0,
                       "Sentence start symbol must be at the beginning of sentence");
        delete lmState;
        lmState = NewState( GetBeginSentenceState() );
      } else {
        updateChartScore( &prefixScore, &finalizedScore, GetValueGivenState(contextFactor, *lmState).score, ++wordPos );
      }
    } else {
      // Non-terminal: splice in the sub-hypothesis that fills it.
      size_t nonTermIndex = nonTermIndexMap[phrasePos];
      const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndex);

      const LanguageModelChartState* prevState =
        static_cast<const LanguageModelChartState*>(prevHypo->GetFFState(featureID));

      size_t subPhraseLength = prevState->GetNumTargetTerminals();

      if (phrasePos == 0) {
        // Rule starts with this non-terminal: adopt the sub-hypothesis'
        // prefix score and right-context state as they are.
        // finalizedScore = -prefixScore makes the running sum
        // (prefix + finalized) start at zero.
        prefixScore = prevState->GetPrefixScore();
        finalizedScore = -prefixScore;

        delete lmState;
        lmState = NewState( prevState->GetRightContext() );

        // Push the last (order - 1) suffix words, or all of them if there
        // are fewer, as history. The start index is computed in size_t with
        // an explicit comparison, so no negative value is produced through
        // unsigned wrap-around and a narrowing size_t -> int conversion.
        const size_t suffixSize = prevState->GetSuffix().GetSize();
        const size_t historyLength = GetNGramOrder() - 1;
        size_t suffixPos = (suffixSize > historyLength) ? suffixSize - historyLength : 0;
        for (; suffixPos < suffixSize; suffixPos++) {
          const Word &word = prevState->GetSuffix().GetWord(suffixPos);
          ShiftOrPush(contextFactor, word);
          wordPos++;
        }
      } else {
        // Internal non-terminal: re-score its first words (up to order - 1
        // of them, limited by its length) in the current context.
        for(size_t prefixPos = 0;
            prefixPos < GetNGramOrder()-1
            && prefixPos < subPhraseLength;
            prefixPos++) {
          const Word &word = prevState->GetPrefix().GetWord(prefixPos);
          ShiftOrPush(contextFactor, word);
          updateChartScore( &prefixScore, &finalizedScore, GetValueGivenState(contextFactor, *lmState).score, ++wordPos );
        }

        // Remove the sub-hypothesis' own prefix score, since those words
        // were just re-scored above. (Presumably that score is already
        // counted via the sub-hypothesis; confirm.)
        finalizedScore -= prevState->GetPrefixScore();

        if (subPhraseLength > GetNGramOrder() - 1) {
          // Sub-phrase longer than the history window: continue from its
          // right-context state and its last words.
          delete lmState;
          lmState = NewState( prevState->GetRightContext() );

          size_t remainingWords = subPhraseLength - (GetNGramOrder()-1);
          if (remainingWords > GetNGramOrder()-1) {
            // Only what the history window can hold is needed.
            remainingWords = GetNGramOrder()-1;
          }
          for(size_t suffixPos = prevState->GetSuffix().GetSize() - remainingWords;
              suffixPos < prevState->GetSuffix().GetSize();
              suffixPos++) {
            const Word &word = prevState->GetSuffix().GetWord(suffixPos);
            ShiftOrPush(contextFactor, word);
          }
          wordPos += subPhraseLength;
        }
      }
    }
  }

  // Report the difference from the rule's pre-computed LM score.
  if (OOVFeatureEnabled()) {
    vector<float> scores(2);
    scores[0] = prefixScore + finalizedScore - hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0];
    scores[1] = 0;
    out->PlusEquals(this, scores);
  } else {
    out->PlusEquals(this, prefixScore + finalizedScore - hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0]);
  }

  ret->Set(prefixScore, lmState);
  return ret;
}
|
|
|
void LanguageModelImplementation::updateChartScore(float *prefixScore, float *finalizedScore, float score, size_t wordPos) const |
|
{ |
|
if (wordPos < GetNGramOrder()) { |
|
*prefixScore += score; |
|
} else { |
|
*finalizedScore += score; |
|
} |
|
} |
|
|
|
} |
|
|