|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <cstring> |
|
#include <iostream> |
|
#include <memory> |
|
#include <cstdlib> |
|
#include <boost/shared_ptr.hpp> |
|
#include <boost/lexical_cast.hpp> |
|
|
|
#include "lm/binary_format.hh" |
|
#include "lm/enumerate_vocab.hh" |
|
#include "lm/left.hh" |
|
#include "lm/model.hh" |
|
#include "util/exception.hh" |
|
#include "util/tokenize_piece.hh" |
|
#include "util/string_stream.hh" |
|
|
|
#include "Ken.h" |
|
#include "Base.h" |
|
#include "moses/FF/FFState.h" |
|
#include "moses/TypeDef.h" |
|
#include "moses/Util.h" |
|
#include "moses/FactorCollection.h" |
|
#include "moses/Phrase.h" |
|
#include "moses/InputFileStream.h" |
|
#include "moses/StaticData.h" |
|
#include "moses/ChartHypothesis.h" |
|
#include "moses/Incremental.h" |
|
#include "moses/Syntax/SHyperedge.h" |
|
#include "moses/Syntax/SVertex.h" |
|
|
|
using namespace std; |
|
|
|
namespace Moses |
|
{ |
|
namespace |
|
{ |
|
|
|
struct KenLMState : public FFState { |
|
lm::ngram::State state; |
|
virtual size_t hash() const { |
|
size_t ret = hash_value(state); |
|
return ret; |
|
} |
|
virtual bool operator==(const FFState& o) const { |
|
const KenLMState &other = static_cast<const KenLMState &>(o); |
|
bool ret = state == other.state; |
|
return ret; |
|
} |
|
|
|
}; |
|
|
|
class MappingBuilder : public lm::EnumerateVocab |
|
{ |
|
public: |
|
MappingBuilder(FactorCollection &factorCollection, std::vector<lm::WordIndex> &mapping) |
|
: m_factorCollection(factorCollection), m_mapping(mapping) {} |
|
|
|
void Add(lm::WordIndex index, const StringPiece &str) { |
|
std::size_t factorId = m_factorCollection.AddFactor(str)->GetId(); |
|
if (m_mapping.size() <= factorId) { |
|
|
|
m_mapping.resize(factorId + 1); |
|
} |
|
m_mapping[factorId] = index; |
|
} |
|
|
|
private: |
|
FactorCollection &m_factorCollection; |
|
std::vector<lm::WordIndex> &m_mapping; |
|
}; |
|
|
|
} |
|
|
|
template <class Model> void LanguageModelKen<Model>::LoadModel(const std::string &file, util::LoadMethod load_method) |
|
{ |
|
m_lmIdLookup.clear(); |
|
|
|
lm::ngram::Config config; |
|
if(this->m_verbosity >= 1) { |
|
config.messages = &std::cerr; |
|
} else { |
|
config.messages = NULL; |
|
} |
|
FactorCollection &collection = FactorCollection::Instance(); |
|
MappingBuilder builder(collection, m_lmIdLookup); |
|
config.enumerate_vocab = &builder; |
|
config.load_method = load_method; |
|
|
|
m_ngram.reset(new Model(file.c_str(), config)); |
|
VERBOSE(2, "LanguageModelKen " << m_description << " reset to " << file << "\n"); |
|
} |
|
|
|
template <class Model> LanguageModelKen<Model>::LanguageModelKen(const std::string &line, const std::string &file, FactorType factorType, util::LoadMethod load_method) |
|
:LanguageModel(line) |
|
,m_beginSentenceFactor(FactorCollection::Instance().AddFactor(BOS_)) |
|
,m_factorType(factorType) |
|
{ |
|
ReadParameters(); |
|
LoadModel(file, load_method); |
|
} |
|
|
|
template <class Model> LanguageModelKen<Model>::LanguageModelKen() |
|
:LanguageModel("KENLM") |
|
,m_beginSentenceFactor(FactorCollection::Instance().AddFactor(BOS_)) |
|
,m_factorType(0) |
|
{ |
|
ReadParameters(); |
|
} |
|
|
|
|
|
template <class Model> LanguageModelKen<Model>::LanguageModelKen(const LanguageModelKen<Model> ©_from) |
|
:LanguageModel(copy_from.GetArgLine()), |
|
m_ngram(copy_from.m_ngram), |
|
|
|
m_beginSentenceFactor(copy_from.m_beginSentenceFactor), |
|
m_factorType(copy_from.m_factorType), |
|
m_lmIdLookup(copy_from.m_lmIdLookup) |
|
{ |
|
} |
|
|
|
template <class Model> const FFState * LanguageModelKen<Model>::EmptyHypothesisState(const InputType &) const |
|
{ |
|
KenLMState *ret = new KenLMState(); |
|
ret->state = m_ngram->BeginSentenceState(); |
|
return ret; |
|
} |
|
|
|
template <class Model> void LanguageModelKen<Model>::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const |
|
{ |
|
fullScore = 0; |
|
ngramScore = 0; |
|
oovCount = 0; |
|
|
|
if (!phrase.GetSize()) return; |
|
|
|
lm::ngram::ChartState discarded_sadly; |
|
lm::ngram::RuleScore<Model> scorer(*m_ngram, discarded_sadly); |
|
|
|
size_t position; |
|
if (m_beginSentenceFactor == phrase.GetWord(0).GetFactor(m_factorType)) { |
|
scorer.BeginSentence(); |
|
position = 1; |
|
} else { |
|
position = 0; |
|
} |
|
|
|
size_t ngramBoundary = m_ngram->Order() - 1; |
|
|
|
size_t end_loop = std::min(ngramBoundary, phrase.GetSize()); |
|
for (; position < end_loop; ++position) { |
|
const Word &word = phrase.GetWord(position); |
|
if (word.IsNonTerminal()) { |
|
fullScore += scorer.Finish(); |
|
scorer.Reset(); |
|
} else { |
|
lm::WordIndex index = TranslateID(word); |
|
scorer.Terminal(index); |
|
if (!index) ++oovCount; |
|
} |
|
} |
|
float before_boundary = fullScore + scorer.Finish(); |
|
for (; position < phrase.GetSize(); ++position) { |
|
const Word &word = phrase.GetWord(position); |
|
if (word.IsNonTerminal()) { |
|
fullScore += scorer.Finish(); |
|
scorer.Reset(); |
|
} else { |
|
lm::WordIndex index = TranslateID(word); |
|
scorer.Terminal(index); |
|
if (!index) ++oovCount; |
|
} |
|
} |
|
fullScore += scorer.Finish(); |
|
|
|
ngramScore = TransformLMScore(fullScore - before_boundary); |
|
fullScore = TransformLMScore(fullScore); |
|
} |
|
|
|
template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const |
|
{ |
|
const lm::ngram::State &in_state = static_cast<const KenLMState&>(*ps).state; |
|
|
|
std::auto_ptr<KenLMState> ret(new KenLMState()); |
|
|
|
if (!hypo.GetCurrTargetLength()) { |
|
ret->state = in_state; |
|
return ret.release(); |
|
} |
|
|
|
const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos(); |
|
|
|
const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1; |
|
const std::size_t adjust_end = std::min(end, begin + m_ngram->Order() - 1); |
|
|
|
std::size_t position = begin; |
|
typename Model::State aux_state; |
|
typename Model::State *state0 = &ret->state, *state1 = &aux_state; |
|
|
|
float score = m_ngram->Score(in_state, TranslateID(hypo.GetWord(position)), *state0); |
|
++position; |
|
for (; position < adjust_end; ++position) { |
|
score += m_ngram->Score(*state0, TranslateID(hypo.GetWord(position)), *state1); |
|
std::swap(state0, state1); |
|
} |
|
|
|
if (hypo.IsSourceCompleted()) { |
|
|
|
std::vector<lm::WordIndex> indices(m_ngram->Order() - 1); |
|
const lm::WordIndex *last = LastIDs(hypo, &indices.front()); |
|
score += m_ngram->FullScoreForgotState(&indices.front(), last, m_ngram->GetVocabulary().EndSentence(), ret->state).prob; |
|
} else if (adjust_end < end) { |
|
|
|
std::vector<lm::WordIndex> indices(m_ngram->Order() - 1); |
|
const lm::WordIndex *last = LastIDs(hypo, &indices.front()); |
|
m_ngram->GetState(&indices.front(), last, ret->state); |
|
} else if (state0 != &ret->state) { |
|
|
|
ret->state = *state0; |
|
} |
|
|
|
score = TransformLMScore(score); |
|
|
|
if (OOVFeatureEnabled()) { |
|
std::vector<float> scores(2); |
|
scores[0] = score; |
|
scores[1] = 0.0; |
|
out->PlusEquals(this, scores); |
|
} else { |
|
out->PlusEquals(this, score); |
|
} |
|
|
|
return ret.release(); |
|
} |
|
|
|
class LanguageModelChartStateKenLM : public FFState |
|
{ |
|
public: |
|
LanguageModelChartStateKenLM() {} |
|
|
|
const lm::ngram::ChartState &GetChartState() const { |
|
return m_state; |
|
} |
|
lm::ngram::ChartState &GetChartState() { |
|
return m_state; |
|
} |
|
|
|
size_t hash() const { |
|
size_t ret = hash_value(m_state); |
|
return ret; |
|
} |
|
virtual bool operator==(const FFState& o) const { |
|
const LanguageModelChartStateKenLM &other = static_cast<const LanguageModelChartStateKenLM &>(o); |
|
bool ret = m_state == other.m_state; |
|
return ret; |
|
} |
|
|
|
private: |
|
lm::ngram::ChartState m_state; |
|
}; |
|
|
|
template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *accumulator) const |
|
{ |
|
LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM(); |
|
lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState()); |
|
const TargetPhrase &target = hypo.GetCurrTargetPhrase(); |
|
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = |
|
target.GetAlignNonTerm().GetNonTermIndexMap(); |
|
|
|
const size_t size = hypo.GetCurrTargetPhrase().GetSize(); |
|
size_t phrasePos = 0; |
|
|
|
if (size) { |
|
const Word &word = hypo.GetCurrTargetPhrase().GetWord(0); |
|
if (word.GetFactor(m_factorType) == m_beginSentenceFactor) { |
|
|
|
ruleScore.BeginSentence(); |
|
phrasePos++; |
|
} else if (word.IsNonTerminal()) { |
|
|
|
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]); |
|
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(prevHypo->GetFFState(featureID))->GetChartState(); |
|
ruleScore.BeginNonTerminal(prevState); |
|
phrasePos++; |
|
} |
|
} |
|
|
|
for (; phrasePos < size; phrasePos++) { |
|
const Word &word = hypo.GetCurrTargetPhrase().GetWord(phrasePos); |
|
if (word.IsNonTerminal()) { |
|
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]); |
|
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(prevHypo->GetFFState(featureID))->GetChartState(); |
|
ruleScore.NonTerminal(prevState); |
|
} else { |
|
ruleScore.Terminal(TranslateID(word)); |
|
} |
|
} |
|
|
|
float score = ruleScore.Finish(); |
|
score = TransformLMScore(score); |
|
score -= hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0]; |
|
|
|
if (OOVFeatureEnabled()) { |
|
std::vector<float> scores(2); |
|
scores[0] = score; |
|
scores[1] = 0.0; |
|
accumulator->PlusEquals(this, scores); |
|
} else { |
|
accumulator->PlusEquals(this, score); |
|
} |
|
return newState; |
|
} |
|
|
|
template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const |
|
{ |
|
LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM(); |
|
lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState()); |
|
const TargetPhrase &target = *hyperedge.label.translation; |
|
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = |
|
target.GetAlignNonTerm().GetNonTermIndexMap2(); |
|
|
|
const size_t size = target.GetSize(); |
|
size_t phrasePos = 0; |
|
|
|
if (size) { |
|
const Word &word = target.GetWord(0); |
|
if (word.GetFactor(m_factorType) == m_beginSentenceFactor) { |
|
|
|
ruleScore.BeginSentence(); |
|
phrasePos++; |
|
} else if (word.IsNonTerminal()) { |
|
|
|
const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]]; |
|
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->states[featureID])->GetChartState(); |
|
ruleScore.BeginNonTerminal(prevState); |
|
phrasePos++; |
|
} |
|
} |
|
|
|
for (; phrasePos < size; phrasePos++) { |
|
const Word &word = target.GetWord(phrasePos); |
|
if (word.IsNonTerminal()) { |
|
const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]]; |
|
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->states[featureID])->GetChartState(); |
|
ruleScore.NonTerminal(prevState); |
|
} else { |
|
ruleScore.Terminal(TranslateID(word)); |
|
} |
|
} |
|
|
|
float score = ruleScore.Finish(); |
|
score = TransformLMScore(score); |
|
score -= target.GetScoreBreakdown().GetScoresForProducer(this)[0]; |
|
|
|
if (OOVFeatureEnabled()) { |
|
std::vector<float> scores(2); |
|
scores[0] = score; |
|
scores[1] = 0.0; |
|
accumulator->PlusEquals(this, scores); |
|
} else { |
|
accumulator->PlusEquals(this, score); |
|
} |
|
return newState; |
|
} |
|
|
|
template <class Model> void LanguageModelKen<Model>::IncrementalCallback(Incremental::Manager &manager) const |
|
{ |
|
manager.LMCallback(*m_ngram, m_lmIdLookup); |
|
} |
|
|
|
template <class Model> void LanguageModelKen<Model>::ReportHistoryOrder(std::ostream &out, const Phrase &phrase) const |
|
{ |
|
out << "|lm=("; |
|
if (!phrase.GetSize()) return; |
|
|
|
typename Model::State aux_state; |
|
typename Model::State start_of_sentence_state = m_ngram->BeginSentenceState(); |
|
typename Model::State *state0 = &start_of_sentence_state; |
|
typename Model::State *state1 = &aux_state; |
|
|
|
for (std::size_t position=0; position<phrase.GetSize(); position++) { |
|
const lm::WordIndex idx = TranslateID(phrase.GetWord(position)); |
|
lm::FullScoreReturn ret(m_ngram->FullScore(*state0, idx, *state1)); |
|
if (position) out << ","; |
|
out << (int) ret.ngram_length << ":" << TransformLMScore(ret.prob); |
|
if (idx == 0) out << ":unk"; |
|
std::swap(state0, state1); |
|
} |
|
out << ")| "; |
|
} |
|
|
|
template <class Model> |
|
bool LanguageModelKen<Model>::IsUseable(const FactorMask &mask) const |
|
{ |
|
bool ret = mask[m_factorType]; |
|
return ret; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template class LanguageModelKen<lm::ngram::ProbingModel>; |
|
template class LanguageModelKen<lm::ngram::RestProbingModel>; |
|
template class LanguageModelKen<lm::ngram::TrieModel>; |
|
template class LanguageModelKen<lm::ngram::ArrayTrieModel>; |
|
template class LanguageModelKen<lm::ngram::QuantTrieModel>; |
|
template class LanguageModelKen<lm::ngram::QuantArrayTrieModel>; |
|
|
|
|
|
LanguageModel *ConstructKenLM(const std::string &lineOrig) |
|
{ |
|
FactorType factorType = 0; |
|
string filePath; |
|
util::LoadMethod load_method = util::POPULATE_OR_READ; |
|
|
|
util::TokenIter<util::SingleCharacter, true> argument(lineOrig, ' '); |
|
++argument; |
|
|
|
util::StringStream line; |
|
line << "KENLM"; |
|
|
|
for (; argument; ++argument) { |
|
const char *equals = std::find(argument->data(), argument->data() + argument->size(), '='); |
|
UTIL_THROW_IF2(equals == argument->data() + argument->size(), |
|
"Expected = in KenLM argument " << *argument); |
|
StringPiece name(argument->data(), equals - argument->data()); |
|
StringPiece value(equals + 1, argument->data() + argument->size() - equals - 1); |
|
if (name == "factor") { |
|
factorType = boost::lexical_cast<FactorType>(value); |
|
} else if (name == "order") { |
|
|
|
} else if (name == "path") { |
|
filePath.assign(value.data(), value.size()); |
|
} else if (name == "lazyken") { |
|
|
|
if (value == "0" || value == "false") { |
|
load_method = util::POPULATE_OR_READ; |
|
} else if (value == "1" || value == "true") { |
|
load_method = util::LAZY; |
|
} else { |
|
UTIL_THROW2("Can't parse lazyken argument " << value << ". Also, lazyken is deprecated. Use load with one of the arguments lazy, populate_or_lazy, populate_or_read, read, or parallel_read."); |
|
} |
|
} else if (name == "load") { |
|
if (value == "lazy") { |
|
load_method = util::LAZY; |
|
} else if (value == "populate_or_lazy") { |
|
load_method = util::POPULATE_OR_LAZY; |
|
} else if (value == "populate_or_read" || value == "populate") { |
|
load_method = util::POPULATE_OR_READ; |
|
} else if (value == "read") { |
|
load_method = util::READ; |
|
} else if (value == "parallel_read") { |
|
load_method = util::PARALLEL_READ; |
|
} else { |
|
UTIL_THROW2("Unknown KenLM load method " << value); |
|
} |
|
} else { |
|
|
|
line << " " << name << "=" << value; |
|
} |
|
} |
|
|
|
return ConstructKenLM(line.str(), filePath, factorType, load_method); |
|
} |
|
|
|
LanguageModel *ConstructKenLM(const std::string &line, const std::string &file, FactorType factorType, util::LoadMethod load_method) |
|
{ |
|
lm::ngram::ModelType model_type; |
|
if (lm::ngram::RecognizeBinary(file.c_str(), model_type)) { |
|
switch(model_type) { |
|
case lm::ngram::PROBING: |
|
return new LanguageModelKen<lm::ngram::ProbingModel>(line, file, factorType, load_method); |
|
case lm::ngram::REST_PROBING: |
|
return new LanguageModelKen<lm::ngram::RestProbingModel>(line, file, factorType, load_method); |
|
case lm::ngram::TRIE: |
|
return new LanguageModelKen<lm::ngram::TrieModel>(line, file, factorType, load_method); |
|
case lm::ngram::QUANT_TRIE: |
|
return new LanguageModelKen<lm::ngram::QuantTrieModel>(line, file, factorType, load_method); |
|
case lm::ngram::ARRAY_TRIE: |
|
return new LanguageModelKen<lm::ngram::ArrayTrieModel>(line, file, factorType, load_method); |
|
case lm::ngram::QUANT_ARRAY_TRIE: |
|
return new LanguageModelKen<lm::ngram::QuantArrayTrieModel>(line, file, factorType, load_method); |
|
default: |
|
UTIL_THROW2("Unrecognized kenlm model type " << model_type); |
|
} |
|
} else { |
|
return new LanguageModelKen<lm::ngram::ProbingModel>(line, file, factorType, load_method); |
|
} |
|
} |
|
|
|
} |
|
|