|
#pragma once |
|
|
|
#include "moses/FF/FFState.h" |
|
#include "moses/ChartHypothesis.h" |
|
#include "moses/ChartManager.h" |
|
|
|
namespace Moses |
|
{ |
|
|
|
class LanguageModelChartState : public FFState |
|
{ |
|
private: |
|
float m_prefixScore; |
|
FFState* m_lmRightContext; |
|
|
|
Phrase m_contextPrefix, m_contextSuffix; |
|
|
|
size_t m_numTargetTerminals; |
|
|
|
const ChartHypothesis &m_hypo; |
|
|
|
|
|
|
|
|
|
|
|
size_t CalcPrefix(const ChartHypothesis &hypo, int featureID, Phrase &ret, size_t size) const { |
|
const TargetPhrase &target = hypo.GetCurrTargetPhrase(); |
|
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = |
|
target.GetAlignNonTerm().GetNonTermIndexMap(); |
|
|
|
|
|
for (size_t pos = 0; pos < target.GetSize(); ++pos) { |
|
const Word &word = target.GetWord(pos); |
|
|
|
|
|
if (word.IsNonTerminal()) { |
|
size_t nonTermInd = nonTermIndexMap[pos]; |
|
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd); |
|
size = static_cast<const LanguageModelChartState*>(prevHypo->GetFFState(featureID))->CalcPrefix(*prevHypo, featureID, ret, size); |
|
} |
|
|
|
else { |
|
ret.AddWord(target.GetWord(pos)); |
|
size--; |
|
} |
|
|
|
|
|
if (size==0) |
|
break; |
|
} |
|
|
|
return size; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
size_t CalcSuffix(const ChartHypothesis &hypo, int featureID, Phrase &ret, size_t size) const { |
|
UTIL_THROW_IF2(m_contextPrefix.GetSize() > m_numTargetTerminals, "Error"); |
|
|
|
|
|
|
|
if (m_contextPrefix.GetSize() == m_numTargetTerminals) { |
|
size_t maxCount = std::min(m_contextPrefix.GetSize(), size); |
|
size_t pos= m_contextPrefix.GetSize() - 1; |
|
|
|
for (size_t ind = 0; ind < maxCount; ++ind) { |
|
const Word &word = m_contextPrefix.GetWord(pos); |
|
ret.PrependWord(word); |
|
--pos; |
|
} |
|
|
|
size -= maxCount; |
|
return size; |
|
} |
|
|
|
else { |
|
const TargetPhrase& target = hypo.GetCurrTargetPhrase(); |
|
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = |
|
target.GetAlignNonTerm().GetNonTermIndexMap(); |
|
for (int pos = (int) target.GetSize() - 1; pos >= 0 ; --pos) { |
|
const Word &word = target.GetWord(pos); |
|
|
|
if (word.IsNonTerminal()) { |
|
size_t nonTermInd = nonTermIndexMap[pos]; |
|
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd); |
|
size = static_cast<const LanguageModelChartState*>(prevHypo->GetFFState(featureID))->CalcSuffix(*prevHypo, featureID, ret, size); |
|
} else { |
|
ret.PrependWord(hypo.GetCurrTargetPhrase().GetWord(pos)); |
|
size--; |
|
} |
|
|
|
if (size==0) |
|
break; |
|
} |
|
|
|
return size; |
|
} |
|
} |
|
|
|
|
|
public: |
|
LanguageModelChartState(const ChartHypothesis &hypo, int featureID, size_t order) |
|
:m_lmRightContext(NULL) |
|
,m_contextPrefix(order - 1) |
|
,m_contextSuffix( order - 1) |
|
,m_hypo(hypo) { |
|
m_numTargetTerminals = hypo.GetCurrTargetPhrase().GetNumTerminals(); |
|
|
|
for (std::vector<const ChartHypothesis*>::const_iterator i = hypo.GetPrevHypos().begin(); i != hypo.GetPrevHypos().end(); ++i) { |
|
|
|
m_numTargetTerminals += static_cast<const LanguageModelChartState*>((*i)->GetFFState(featureID))->GetNumTargetTerminals(); |
|
} |
|
|
|
CalcPrefix(hypo, featureID, m_contextPrefix, order - 1); |
|
CalcSuffix(hypo, featureID, m_contextSuffix, order - 1); |
|
} |
|
|
|
~LanguageModelChartState() { |
|
delete m_lmRightContext; |
|
} |
|
|
|
void Set(float prefixScore, FFState *rightState) { |
|
m_prefixScore = prefixScore; |
|
m_lmRightContext = rightState; |
|
} |
|
|
|
float GetPrefixScore() const { |
|
return m_prefixScore; |
|
} |
|
FFState* GetRightContext() const { |
|
return m_lmRightContext; |
|
} |
|
|
|
size_t GetNumTargetTerminals() const { |
|
return m_numTargetTerminals; |
|
} |
|
|
|
const Phrase &GetPrefix() const { |
|
return m_contextPrefix; |
|
} |
|
const Phrase &GetSuffix() const { |
|
return m_contextSuffix; |
|
} |
|
|
|
size_t hash() const { |
|
size_t ret; |
|
|
|
|
|
ret = m_hypo.GetCurrSourceRange().GetStartPos() > 0; |
|
if (m_hypo.GetCurrSourceRange().GetStartPos() > 0) { |
|
size_t hash = hash_value(GetPrefix()); |
|
boost::hash_combine(ret, hash); |
|
} |
|
|
|
|
|
size_t inputSize = m_hypo.GetManager().GetSource().GetSize(); |
|
boost::hash_combine(ret, m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1); |
|
if (m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1) { |
|
size_t hash = m_lmRightContext->hash(); |
|
boost::hash_combine(ret, hash); |
|
} |
|
|
|
return ret; |
|
} |
|
virtual bool operator==(const FFState& o) const { |
|
const LanguageModelChartState &other = |
|
static_cast<const LanguageModelChartState &>( o ); |
|
|
|
|
|
if (m_hypo.GetCurrSourceRange().GetStartPos() > 0) { |
|
bool ret = GetPrefix() == other.GetPrefix(); |
|
if (ret == false) |
|
return false; |
|
} |
|
|
|
|
|
size_t inputSize = m_hypo.GetManager().GetSource().GetSize(); |
|
if (m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1) { |
|
bool ret = (*other.GetRightContext()) == (*m_lmRightContext); |
|
return ret; |
|
} |
|
return true; |
|
} |
|
|
|
}; |
|
|
|
} |
|
|
|
|