|
#pragma once |
|
|
|
#include "moses/DecodeGraph.h" |
|
#include "moses/StaticData.h" |
|
#include "moses/Syntax/BoundedPriorityContainer.h" |
|
#include "moses/Syntax/CubeQueue.h" |
|
#include "moses/Syntax/F2S/DerivationWriter.h" |
|
#include "moses/Syntax/F2S/RuleMatcherCallback.h" |
|
#include "moses/Syntax/PHyperedge.h" |
|
#include "moses/Syntax/RuleTable.h" |
|
#include "moses/Syntax/RuleTableFF.h" |
|
#include "moses/Syntax/SHyperedgeBundle.h" |
|
#include "moses/Syntax/SVertex.h" |
|
#include "moses/Syntax/SVertexRecombinationEqualityPred.h" |
|
#include "moses/Syntax/SVertexRecombinationHasher.h" |
|
#include "moses/Syntax/SymbolEqualityPred.h" |
|
#include "moses/Syntax/SymbolHasher.h" |
|
|
|
#include "GlueRuleSynthesizer.h" |
|
#include "InputTreeBuilder.h" |
|
#include "RuleTrie.h" |
|
|
|
namespace Moses |
|
{ |
|
namespace Syntax |
|
{ |
|
namespace T2S |
|
{ |
|
|
|
template<typename RuleMatcher> |
|
Manager<RuleMatcher>::Manager(ttasksptr const& ttask) |
|
: Syntax::Manager(ttask) |
|
{ |
|
if (const TreeInput *p = dynamic_cast<const TreeInput*>(&m_source)) { |
|
|
|
InputTreeBuilder builder(options()->output.factor_order); |
|
builder.Build(*p, "Q", m_inputTree); |
|
} else { |
|
UTIL_THROW2("ERROR: T2S::Manager requires input to be a tree"); |
|
} |
|
} |
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::InitializeRuleMatchers() |
|
{ |
|
const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); |
|
for (std::size_t i = 0; i < ffs.size(); ++i) { |
|
RuleTableFF *ff = ffs[i]; |
|
|
|
|
|
|
|
|
|
const RuleTable *table = ff->GetTable(); |
|
assert(table); |
|
RuleTable *nonConstTable = const_cast<RuleTable*>(table); |
|
RuleTrie *trie = dynamic_cast<RuleTrie*>(nonConstTable); |
|
assert(trie); |
|
boost::shared_ptr<RuleMatcher> p(new RuleMatcher(m_inputTree, *trie)); |
|
m_ruleMatchers.push_back(p); |
|
} |
|
|
|
|
|
|
|
|
|
m_glueRuleTrie.reset(new RuleTrie(ffs[0])); |
|
boost::shared_ptr<RuleMatcher> p(new RuleMatcher(m_inputTree, *m_glueRuleTrie)); |
|
m_ruleMatchers.push_back(p); |
|
m_glueRuleMatcher = p.get(); |
|
} |
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::InitializeStacks() |
|
{ |
|
|
|
assert(!m_inputTree.nodes.empty()); |
|
|
|
for (std::vector<InputTree::Node>::const_iterator p = |
|
m_inputTree.nodes.begin(); p != m_inputTree.nodes.end(); ++p) { |
|
const InputTree::Node &node = *p; |
|
|
|
|
|
SVertexStack &stack = m_stackMap[&(node.pvertex)]; |
|
|
|
|
|
if (node.children.empty()) { |
|
boost::shared_ptr<SVertex> v(new SVertex()); |
|
v->best = 0; |
|
v->pvertex = &(node.pvertex); |
|
stack.push_back(v); |
|
} |
|
} |
|
} |
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::Decode() |
|
{ |
|
|
|
|
|
|
|
const std::size_t popLimit = this->options()->cube.pop_limit; |
|
const std::size_t ruleLimit = this->options()->syntax.rule_limit; |
|
const std::size_t stackLimit = this->options()->search.stack_size; |
|
|
|
|
|
InitializeStacks(); |
|
|
|
|
|
InitializeRuleMatchers(); |
|
|
|
|
|
F2S::RuleMatcherCallback callback(m_stackMap, ruleLimit); |
|
|
|
|
|
Word dflt_nonterm = options()->syntax.output_default_non_terminal; |
|
GlueRuleSynthesizer glueRuleSynthesizer(*m_glueRuleTrie, dflt_nonterm); |
|
|
|
|
|
for (std::vector<InputTree::Node>::const_iterator p = |
|
m_inputTree.nodes.begin(); p != m_inputTree.nodes.end(); ++p) { |
|
|
|
const InputTree::Node &node = *p; |
|
|
|
|
|
if (node.children.empty()) { |
|
continue; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
callback.ClearContainer(); |
|
for (typename std::vector<boost::shared_ptr<RuleMatcher> >::iterator |
|
q = m_ruleMatchers.begin(); q != m_ruleMatchers.end(); ++q) { |
|
(*q)->EnumerateHyperedges(node, callback); |
|
} |
|
|
|
|
|
const BoundedPriorityContainer<SHyperedgeBundle> &bundles = |
|
callback.GetContainer(); |
|
|
|
|
|
|
|
if (bundles.Size() == 0) { |
|
glueRuleSynthesizer.SynthesizeRule(node); |
|
m_glueRuleMatcher->EnumerateHyperedges(node, callback); |
|
assert(bundles.Size() == 1); |
|
} |
|
|
|
|
|
|
|
CubeQueue cubeQueue(bundles.Begin(), bundles.End()); |
|
std::size_t count = 0; |
|
std::vector<SHyperedge*> buffer; |
|
while (count < popLimit && !cubeQueue.IsEmpty()) { |
|
SHyperedge *hyperedge = cubeQueue.Pop(); |
|
|
|
|
|
hyperedge->head->pvertex = &(node.pvertex); |
|
|
|
buffer.push_back(hyperedge); |
|
++count; |
|
} |
|
|
|
|
|
SVertexStack &stack = m_stackMap[&(node.pvertex)]; |
|
RecombineAndSort(buffer, stack); |
|
|
|
|
|
if (stackLimit > 0 && stack.size() > stackLimit) { |
|
stack.resize(stackLimit); |
|
} |
|
} |
|
} |
|
|
|
template<typename RuleMatcher> |
|
const SHyperedge *Manager<RuleMatcher>::GetBestSHyperedge() const |
|
{ |
|
const InputTree::Node &rootNode = m_inputTree.nodes.back(); |
|
F2S::PVertexToStackMap::const_iterator p = m_stackMap.find(&rootNode.pvertex); |
|
assert(p != m_stackMap.end()); |
|
const SVertexStack &stack = p->second; |
|
assert(!stack.empty()); |
|
return stack[0]->best; |
|
} |
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::ExtractKBest( |
|
std::size_t k, |
|
std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList, |
|
bool onlyDistinct) const |
|
{ |
|
kBestList.clear(); |
|
if (k == 0 || m_source.GetSize() == 0) { |
|
return; |
|
} |
|
|
|
|
|
const InputTree::Node &rootNode = m_inputTree.nodes.back(); |
|
F2S::PVertexToStackMap::const_iterator p = m_stackMap.find(&rootNode.pvertex); |
|
assert(p != m_stackMap.end()); |
|
const SVertexStack &stack = p->second; |
|
assert(!stack.empty()); |
|
|
|
KBestExtractor extractor; |
|
|
|
if (!onlyDistinct) { |
|
|
|
extractor.Extract(stack, k, kBestList); |
|
return; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const std::size_t nBestFactor = this->options()->nbest.factor; |
|
std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor; |
|
|
|
|
|
KBestExtractor::KBestVec bigList; |
|
bigList.reserve(numDerivations); |
|
extractor.Extract(stack, numDerivations, bigList); |
|
|
|
|
|
std::set<Phrase> distinct; |
|
for (KBestExtractor::KBestVec::const_iterator p = bigList.begin(); |
|
kBestList.size() < k && p != bigList.end(); ++p) { |
|
boost::shared_ptr<KBestExtractor::Derivation> derivation = *p; |
|
Phrase translation = KBestExtractor::GetOutputPhrase(*derivation); |
|
if (distinct.insert(translation).second) { |
|
kBestList.push_back(derivation); |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::RecombineAndSort( |
|
const std::vector<SHyperedge*> &buffer, SVertexStack &stack) |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
typedef boost::unordered_map<SVertex *, SVertex *, |
|
SVertexRecombinationHasher, |
|
SVertexRecombinationEqualityPred> Map; |
|
Map map; |
|
for (std::vector<SHyperedge*>::const_iterator p = buffer.begin(); |
|
p != buffer.end(); ++p) { |
|
SHyperedge *h = *p; |
|
SVertex *v = h->head; |
|
assert(v->best == h); |
|
assert(v->recombined.empty()); |
|
std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v)); |
|
if (result.second) { |
|
continue; |
|
} |
|
|
|
|
|
|
|
SVertex *storedVertex = result.first->second; |
|
if (h->label.futureScore > storedVertex->best->label.futureScore) { |
|
|
|
storedVertex->recombined.push_back(storedVertex->best); |
|
storedVertex->best = h; |
|
} else { |
|
storedVertex->recombined.push_back(h); |
|
} |
|
h->head->best = 0; |
|
delete h->head; |
|
h->head = storedVertex; |
|
} |
|
|
|
|
|
stack.clear(); |
|
stack.reserve(map.size()); |
|
for (Map::const_iterator p = map.begin(); p != map.end(); ++p) { |
|
stack.push_back(boost::shared_ptr<SVertex>(p->first)); |
|
} |
|
|
|
|
|
std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer()); |
|
} |
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::OutputDetailedTranslationReport( |
|
OutputCollector *collector) const |
|
{ |
|
const SHyperedge *best = GetBestSHyperedge(); |
|
if (best == NULL || collector == NULL) { |
|
return; |
|
} |
|
long translationId = m_source.GetTranslationId(); |
|
std::ostringstream out; |
|
F2S::DerivationWriter::Write(*best, translationId, out); |
|
collector->Write(translationId, out.str()); |
|
} |
|
|
|
} |
|
} |
|
} |
|
|