|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <algorithm> |
|
#include <vector> |
|
#include "ChartHypothesis.h" |
|
#include "RuleCubeItem.h" |
|
#include "ChartCell.h" |
|
#include "ChartManager.h" |
|
#include "TargetPhrase.h" |
|
#include "Phrase.h" |
|
#include "StaticData.h" |
|
#include "ChartTranslationOptions.h" |
|
#include "moses/FF/FFState.h" |
|
#include "moses/FF/StatefulFeatureFunction.h" |
|
#include "moses/FF/StatelessFeatureFunction.h" |
|
|
|
using namespace std; |
|
|
|
namespace Moses |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
/** Build a hypothesis from one item of a rule cube.
 *
 * \param transOpt supplies the source word range covered by this hypothesis
 * \param item     supplies the translation option (via its translation
 *                 dimension) and the child hypotheses (via its hypothesis
 *                 dimensions)
 * \param manager  kept by reference; also hands out the id of this hypothesis
 *
 * One state slot per stateful feature function is allocated here; the slots
 * are assigned later by EvaluateWhenApplied(). The arc list and the winning
 * hypothesis both start out as NULL.
 */
ChartHypothesis::ChartHypothesis(const ChartTranslationOptions &transOpt,
                                 const RuleCubeItem &item,
                                 ChartManager &manager)
  :m_transOpt(item.GetTranslationDimension().GetTranslationOption())
  ,m_currSourceWordsRange(transOpt.GetSourceWordsRange())
  ,m_ffStates(StatefulFeatureFunction::GetStatefulFeatureFunctions().size())
  ,m_arcList(NULL)
  ,m_winningHypo(NULL)
  ,m_manager(manager)
  ,m_id(manager.GetNextHypoId())
{
  // Remember the hypothesis selected in each hypothesis dimension, in order.
  const std::vector<HypothesisDimension> &dims = item.GetHypothesisDimensions();
  const size_t numDims = dims.size();
  m_prevHypos.reserve(numDims);
  for (size_t d = 0; d < numDims; ++d) {
    m_prevHypos.push_back(dims[d].GetHypothesis());
  }
}
|
|
|
|
|
|
|
/** Build a hypothesis whose only predecessor is \p pred.
 *
 * The source range, total score and manager are taken over from \p pred, and
 * a fresh id is requested from pred's manager. The ChartKBestExtractor
 * argument is unnamed and never read in the body (presumably it only selects
 * this overload -- confirm against the header).
 */
ChartHypothesis::ChartHypothesis(const ChartHypothesis &pred,
                                 const ChartKBestExtractor & /*unused*/)
  :m_currSourceWordsRange(pred.m_currSourceWordsRange)
  ,m_totalScore(pred.m_totalScore)
  ,m_arcList(NULL)
  ,m_winningHypo(NULL)
  ,m_manager(pred.m_manager)
  ,m_id(pred.m_manager.GetNextHypoId())
{
  // The predecessor becomes the single entry of the child list.
  const ChartHypothesis *predecessor = &pred;
  m_prevHypos.push_back(predecessor);
}
|
|
|
/** Free the feature-function states and, if one exists, the arc list
 * together with every hypothesis stored in it.
 */
ChartHypothesis::~ChartHypothesis()
{
  // Feature-function states are deleted here, one per slot.
  for (size_t s = 0; s != m_ffStates.size(); ++s) {
    delete m_ffStates[s];
  }

  // No arc list, nothing more to release.
  if (m_arcList == NULL) {
    return;
  }

  // Delete the hypotheses held in the arc list, then the list itself.
  ChartArcList &arcs = *m_arcList;
  for (ChartArcList::iterator arc = arcs.begin(); arc != arcs.end(); ++arc) {
    delete *arc;
  }
  arcs.clear();

  delete m_arcList;
}
|
|
|
|
|
|
|
|
|
/** Append the complete target-side output of this hypothesis to \p outPhrase.
 *
 * Terminals of the current target phrase are appended directly; each
 * non-terminal is expanded by recursing into the child hypothesis it maps to.
 * When a placeholder factor is configured and a terminal aligns to exactly
 * one source position, factor 0 of the word just appended is replaced by the
 * source word's placeholder factor, if that factor is set.
 *
 * \param outPhrase phrase that receives the words (appended, never cleared)
 */
void ChartHypothesis::GetOutputPhrase(Phrase &outPhrase) const
{
  const FactorType placeholderFactor = StaticData::Instance().options()->input.placeholder_factor;
  const TargetPhrase &tp = GetCurrTargetPhrase();

  for (size_t pos = 0; pos < tp.GetSize(); ++pos) {
    const Word &word = tp.GetWord(pos);

    if (word.IsNonTerminal()) {
      // Non-terminal: splice in the output of the corresponding child.
      const size_t childIdx = tp.GetAlignNonTerm().GetNonTermIndexMap()[pos];
      m_prevHypos[childIdx]->GetOutputPhrase(outPhrase);
      continue;
    }

    outPhrase.AddWord(word);

    // Placeholder substitution is only attempted when a factor is configured...
    if (placeholderFactor == NOT_FOUND) {
      continue;
    }

    // ...and the target word is aligned to exactly one source position.
    std::set<size_t> sourcePosSet = tp.GetAlignTerm().GetAlignmentsForTarget(pos);
    if (sourcePosSet.size() != 1) {
      continue;
    }

    const std::vector<const Word*> *ruleSourceFromInputPath = GetTranslationOption().GetSourceRuleFromInputPath();
    UTIL_THROW_IF2(ruleSourceFromInputPath == NULL,
                   "No source rule");

    const Word *sourceWord = ruleSourceFromInputPath->at(*sourcePosSet.begin());
    UTIL_THROW_IF2(sourceWord == NULL,
                   "No source word");

    const Factor *factor = sourceWord->GetFactor(placeholderFactor);
    if (factor) {
      // Overwrite factor 0 of the word that was just appended.
      outPhrase.Back()[0] = factor;
    }
  }
}
|
|
|
|
|
/** Return the complete target-side output of this hypothesis as a new phrase.
 *
 * Convenience wrapper around GetOutputPhrase(Phrase&): the result is created
 * with ARRAY_SIZE_INCR as its constructor argument and then filled in.
 */
Phrase ChartHypothesis::GetOutputPhrase() const
{
  Phrase result(ARRAY_SIZE_INCR);
  this->GetOutputPhrase(result);
  return result;
}
|
|
|
|
|
void ChartHypothesis::GetOutputPhrase(size_t leftRightMost, size_t numWords, Phrase &outPhrase) const |
|
{ |
|
const TargetPhrase &tp = GetCurrTargetPhrase(); |
|
|
|
size_t targetSize = tp.GetSize(); |
|
for (size_t i = 0; i < targetSize; ++i) { |
|
size_t pos; |
|
if (leftRightMost == 1) { |
|
pos = i; |
|
} else if (leftRightMost == 2) { |
|
pos = targetSize - i - 1; |
|
} else { |
|
abort(); |
|
} |
|
|
|
const Word &word = tp.GetWord(pos); |
|
|
|
if (word.IsNonTerminal()) { |
|
|
|
size_t nonTermInd = tp.GetAlignNonTerm().GetNonTermIndexMap()[pos]; |
|
const ChartHypothesis *prevHypo = m_prevHypos[nonTermInd]; |
|
prevHypo->GetOutputPhrase(outPhrase); |
|
} else { |
|
outPhrase.AddWord(word); |
|
} |
|
|
|
if (outPhrase.GetSize() >= numWords) { |
|
return; |
|
} |
|
} |
|
} |
|
|
|
|
|
void ChartHypothesis::EvaluateWhenApplied() |
|
{ |
|
const StaticData &staticData = StaticData::Instance(); |
|
|
|
|
|
|
|
const std::vector<const StatelessFeatureFunction*>& sfs = |
|
StatelessFeatureFunction::GetStatelessFeatureFunctions(); |
|
for (unsigned i = 0; i < sfs.size(); ++i) { |
|
if (! staticData.IsFeatureFunctionIgnored( *sfs[i] )) { |
|
sfs[i]->EvaluateWhenApplied(*this,&m_currScoreBreakdown); |
|
} |
|
} |
|
|
|
const std::vector<const StatefulFeatureFunction*>& ffs = |
|
StatefulFeatureFunction::GetStatefulFeatureFunctions(); |
|
for (unsigned i = 0; i < ffs.size(); ++i) { |
|
if (! staticData.IsFeatureFunctionIgnored( *ffs[i] )) { |
|
m_ffStates[i] = ffs[i]->EvaluateWhenApplied(*this,i,&m_currScoreBreakdown); |
|
} |
|
} |
|
|
|
|
|
m_totalScore = GetTranslationOption().GetScores().GetWeightedScore(); |
|
m_totalScore += m_currScoreBreakdown.GetWeightedScore(); |
|
|
|
|
|
for (std::vector<const ChartHypothesis*>::const_iterator iter = m_prevHypos.begin(); iter != m_prevHypos.end(); ++iter) { |
|
const ChartHypothesis &prevHypo = **iter; |
|
m_totalScore += prevHypo.GetFutureScore(); |
|
} |
|
} |
|
|
|
/** Record \p loserHypo as an arc of this hypothesis.
 *
 * Any arcs already collected by \p loserHypo are moved into this hypothesis'
 * arc list first, keeping their order, and \p loserHypo is left without an
 * arc list. \p loserHypo itself is then appended as the last arc.
 *
 * \param loserHypo hypothesis to attach; must not be NULL. The destructor of
 *                  this class deletes everything stored in the arc list.
 */
void ChartHypothesis::AddArc(ChartHypothesis *loserHypo)
{
  ChartArcList *loserArcs = loserHypo->m_arcList;

  if (!m_arcList) {
    // No list of our own yet: adopt the loser's list as-is, or start a new one.
    m_arcList = loserArcs ? loserArcs : new ChartArcList();
  } else if (loserArcs) {
    // Both sides have arcs: append the loser's arcs to ours. A range insert
    // is well-defined even when the loser's list is empty, unlike taking the
    // address of element 0 for a raw memory copy.
    m_arcList->insert(m_arcList->end(), loserArcs->begin(), loserArcs->end());
    delete loserArcs;
  }

  // In every case the loser no longer has an arc list of its own.
  loserHypo->m_arcList = NULL;

  m_arcList->push_back(loserHypo);
}
|
|
|
|
|
struct CompareChartHypothesisTotalScore { |
|
bool operator()(const ChartHypothesis* hypo1, const ChartHypothesis* hypo2) const { |
|
return hypo1->GetFutureScore() > hypo2->GetFutureScore(); |
|
} |
|
}; |
|
|
|
void ChartHypothesis::CleanupArcList() |
|
{ |
|
|
|
m_winningHypo = this; |
|
|
|
if (!m_arcList) return; |
|
|
|
|
|
|
|
|
|
|
|
AllOptions const& opts = *StaticData::Instance().options(); |
|
size_t nBestSize = opts.nbest.nbest_size; |
|
bool distinctNBest = (opts.nbest.only_distinct |
|
|| opts.mbr.enabled |
|
|| opts.output.NeedSearchGraph() |
|
|| !opts.output.SearchGraphHG.empty()); |
|
|
|
if (!distinctNBest && m_arcList->size() > nBestSize) { |
|
|
|
NTH_ELEMENT4(m_arcList->begin() |
|
, m_arcList->begin() + nBestSize - 1 |
|
, m_arcList->end() |
|
, CompareChartHypothesisTotalScore()); |
|
|
|
|
|
ChartArcList::iterator iter; |
|
for (iter = m_arcList->begin() + nBestSize ; iter != m_arcList->end() ; ++iter) { |
|
ChartHypothesis *arc = *iter; |
|
delete arc; |
|
} |
|
m_arcList->erase(m_arcList->begin() + nBestSize |
|
, m_arcList->end()); |
|
} |
|
|
|
|
|
ChartArcList::iterator iter = m_arcList->begin(); |
|
for (; iter != m_arcList->end() ; ++iter) { |
|
ChartHypothesis *arc = *iter; |
|
arc->SetWinningHypo(this); |
|
} |
|
|
|
|
|
} |
|
|
|
void ChartHypothesis::SetWinningHypo(const ChartHypothesis *hypo) |
|
{ |
|
m_winningHypo = hypo; |
|
} |
|
|
|
size_t ChartHypothesis::hash() const |
|
{ |
|
size_t seed = 0; |
|
|
|
|
|
for (size_t i = 0; i < m_ffStates.size(); ++i) { |
|
const FFState *state = m_ffStates[i]; |
|
size_t hash = state->hash(); |
|
boost::hash_combine(seed, hash); |
|
} |
|
return seed; |
|
|
|
} |
|
|
|
bool ChartHypothesis::operator==(const ChartHypothesis& other) const |
|
{ |
|
|
|
for (size_t i = 0; i < m_ffStates.size(); ++i) { |
|
const FFState &thisState = *m_ffStates[i]; |
|
const FFState &otherState = *other.m_ffStates[i]; |
|
if (thisState != otherState) { |
|
return false; |
|
} |
|
} |
|
return true; |
|
} |
|
|
|
// TO_STRING_BODY is a macro defined outside this file; presumably it emits
// the string-conversion member of ChartHypothesis -- confirm against its
// definition.
TO_STRING_BODY(ChartHypothesis)

/** Write a one-line description of \p hypo to \p out.
 *
 * Fields, in order: the id; "->" plus the winner's id when a winner is set
 * and is not \p hypo itself; optionally the target LHS followed by "=>"
 * (when include_lhs_in_search_graph is on); the current target phrase; the
 * current source range; the ids of the child hypotheses; the future score
 * as "[total=...]"; and the score breakdown.
 *
 * \return \p out, to allow chaining
 */
std::ostream& operator<<(std::ostream& out, const ChartHypothesis& hypo)
{
  out << hypo.GetId();

  // Only mention the winner when it is a different hypothesis.
  const ChartHypothesis *winner = hypo.GetWinningHypothesis();
  if (winner != NULL && winner != &hypo) {
    out << "->" << winner->GetId();
  }

  if (hypo.GetManager().options()->output.include_lhs_in_search_graph) {
    out << " " << hypo.GetTargetLHS() << "=>";
  }

  out << " " << hypo.GetCurrTargetPhrase();
  out << " " << hypo.GetCurrSourceRange();

  // Ids of the child hypotheses, in stored order.
  for (HypoList::const_iterator child = hypo.GetPrevHypos().begin();
       child != hypo.GetPrevHypos().end(); ++child) {
    out << " " << (*child)->GetId();
  }

  out << " [total=" << hypo.GetFutureScore() << "]";
  out << " " << hypo.GetScoreBreakdown();

  return out;
}
|
|
|
} |
|
|