|
|
|
#pragma once |
|
|
|
#include <string> |
|
#include <vector> |
|
#include "SingleFactor.h" |
|
#include <boost/thread/tss.hpp> |
|
#include "lm/model.hh" |
|
#include "moses/LM/Ken.h" |
|
#include "moses/FF/FFState.h" |
|
|
|
namespace Moses |
|
{ |
|
|
|
struct InMemoryPerSentenceOnDemandLMState : public FFState { |
|
lm::ngram::State state; |
|
virtual size_t hash() const { |
|
size_t ret = hash_value(state); |
|
return ret; |
|
} |
|
virtual bool operator==(const FFState& o) const { |
|
const InMemoryPerSentenceOnDemandLMState &other = static_cast<const InMemoryPerSentenceOnDemandLMState &>(o); |
|
bool ret = state == other.state; |
|
return ret; |
|
} |
|
|
|
}; |
|
|
|
class InMemoryPerSentenceOnDemandLM : public LanguageModel |
|
{ |
|
public: |
|
InMemoryPerSentenceOnDemandLM(const std::string &line); |
|
~InMemoryPerSentenceOnDemandLM(); |
|
|
|
void InitializeForInput(ttasksptr const& ttask); |
|
|
|
virtual void SetParameter(const std::string& key, const std::string& value) { |
|
GetPerThreadLM().SetParameter(key, value); |
|
} |
|
|
|
virtual const FFState* EmptyHypothesisState(const InputType &input) const { |
|
if (isInitialized()) { |
|
return GetPerThreadLM().EmptyHypothesisState(input); |
|
} else { |
|
return new InMemoryPerSentenceOnDemandLMState(); |
|
} |
|
} |
|
|
|
virtual FFState *EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const { |
|
if (isInitialized()) { |
|
return GetPerThreadLM().EvaluateWhenApplied(hypo, ps, out); |
|
} else { |
|
UTIL_THROW(util::Exception, "Can't evaluate an uninitialized LM\n"); |
|
} |
|
} |
|
|
|
virtual FFState *EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const { |
|
if (isInitialized()) { |
|
return GetPerThreadLM().EvaluateWhenApplied(cur_hypo, featureID, accumulator); |
|
} else { |
|
UTIL_THROW(util::Exception, "Can't evaluate an uninitialized LM\n"); |
|
} |
|
} |
|
|
|
virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const { |
|
if (isInitialized()) { |
|
return GetPerThreadLM().EvaluateWhenApplied(hyperedge, featureID, accumulator); |
|
} else { |
|
UTIL_THROW(util::Exception, "Can't evaluate an uninitialized LM\n"); |
|
} |
|
} |
|
|
|
|
|
virtual void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, std::size_t &oovCount) const { |
|
if (isInitialized()) { |
|
GetPerThreadLM().CalcScore(phrase, fullScore, ngramScore, oovCount); |
|
} else { |
|
UTIL_THROW(util::Exception, "WARNING: InMemoryPerSentenceOnDemand::CalcScore called prior to being initialized"); |
|
} |
|
} |
|
|
|
virtual void CalcScoreFromCache(const Phrase &phrase, float &fullScore, float &ngramScore, std::size_t &oovCount) const { |
|
if (isInitialized()) { |
|
GetPerThreadLM().CalcScoreFromCache(phrase, fullScore, ngramScore, oovCount); |
|
} else { |
|
UTIL_THROW(util::Exception, "WARNING: InMemoryPerSentenceOnDemand::CalcScoreFromCache called prior to being initialized"); |
|
} |
|
} |
|
|
|
virtual void IssueRequestsFor(Hypothesis& hypo, const FFState* input_state) { |
|
if (isInitialized()) { |
|
GetPerThreadLM().IssueRequestsFor(hypo, input_state); |
|
} else { |
|
UTIL_THROW(util::Exception, "WARNING: InMemoryPerSentenceOnDemand::IssueRequestsFor called prior to being initialized"); |
|
} |
|
} |
|
|
|
virtual void sync() { |
|
if (isInitialized()) { |
|
GetPerThreadLM().sync(); |
|
} else { |
|
UTIL_THROW(util::Exception, "WARNING: InMemoryPerSentenceOnDemand::sync called prior to being initialized"); |
|
} |
|
} |
|
|
|
virtual void SetFFStateIdx(int state_idx) { |
|
if (isInitialized()) { |
|
GetPerThreadLM().SetFFStateIdx(state_idx); |
|
} else { |
|
UTIL_THROW(util::Exception, "WARNING: InMemoryPerSentenceOnDemand::SetFFStateIdx called prior to being initialized"); |
|
} |
|
} |
|
|
|
virtual void IncrementalCallback(Incremental::Manager &manager) const { |
|
if (isInitialized()) { |
|
GetPerThreadLM().IncrementalCallback(manager); |
|
} else { |
|
UTIL_THROW(util::Exception, "WARNING: InMemoryPerSentenceOnDemand::IncrementalCallback called prior to being initialized"); |
|
} |
|
} |
|
|
|
virtual void ReportHistoryOrder(std::ostream &out,const Phrase &phrase) const { |
|
if (isInitialized()) { |
|
GetPerThreadLM().ReportHistoryOrder(out, phrase); |
|
} else { |
|
UTIL_THROW(util::Exception, "WARNING: InMemoryPerSentenceOnDemand::ReportHistoryOrder called prior to being initialized"); |
|
} |
|
} |
|
|
|
virtual void EvaluateInIsolation(const Phrase &source |
|
, const TargetPhrase &targetPhrase |
|
, ScoreComponentCollection &scoreBreakdown |
|
, ScoreComponentCollection &estimatedScores) const { |
|
if (isInitialized()) { |
|
GetPerThreadLM().EvaluateInIsolation(source, targetPhrase, scoreBreakdown, estimatedScores); |
|
} else { |
|
|
|
} |
|
} |
|
|
|
bool IsUseable(const FactorMask &mask) const { |
|
bool ret = mask[m_factorType]; |
|
return ret; |
|
} |
|
|
|
|
|
protected: |
|
LanguageModelKen<lm::ngram::ProbingModel> & GetPerThreadLM() const; |
|
|
|
mutable boost::thread_specific_ptr<LanguageModelKen<lm::ngram::ProbingModel> > m_perThreadLM; |
|
mutable boost::thread_specific_ptr<std::string> m_tmpFilename; |
|
|
|
FactorType m_factorType; |
|
|
|
bool isInitialized() const { |
|
if (m_tmpFilename.get() == NULL) { |
|
return false; |
|
} else { |
|
return true; |
|
} |
|
} |
|
|
|
}; |
|
|
|
|
|
} |
|
|